Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >深度学习里面,请问有写train函数的模板吗?

深度学习里面,请问有写train函数的模板吗?

作者头像
lyhue1991
发布于 2023-02-23 05:12:25
发布于 2023-02-23 05:12:25
1.2K00
代码可运行
举报
运行总次数:0
代码可运行

知乎热门问题:深度学习里面,请问有写train函数的模板吗?

以下是 知乎用户 吃货本货 的回答。

老师,这题我会。

一般pytorch需要用户自定义训练循环,可以说有1000个pytorch用户就有1000种训练代码风格。 从实用角度讲,一个优秀的训练循环应当具备以下特点。

  • 代码简洁易懂 【模块化、易修改、short-enough】
  • 支持常用功能 【进度条、评估指标、early-stopping】

经过反复斟酌测试,我精心设计了仿照keras风格的pytorch训练循环。诸君且看。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import os,sys,time
import numpy as np
import pandas as pd
import datetime 
from tqdm import tqdm 

import torch
from torch import nn 
from copy import deepcopy

def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)
    print(str(info)+"\n")

class StepRunner:
    def __init__(self, net, loss_fn,
                 stage = "train", metrics_dict = None, 
                 optimizer = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer = optimizer

    def step(self, features, labels):
        #loss
        preds = self.net(features)
        loss = self.loss_fn(preds,labels)

        #backward()
        if self.optimizer is not None and self.stage=="train": 
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        #metrics
        step_metrics = {self.stage+"_"+name:metric_fn(preds, labels).item() 
                        for name,metric_fn in self.metrics_dict.items()}
        return loss.item(),step_metrics

    def train_step(self,features,labels):
        self.net.train() #训练模式, dropout层发生作用
        return self.step(features,labels)

    @torch.no_grad()
    def eval_step(self,features,labels):
        self.net.eval() #预测模式, dropout层不发生作用
        return self.step(features,labels)

    def __call__(self,features,labels):
        if self.stage=="train":
            return self.train_step(features,labels) 
        else:
            return self.eval_step(features,labels)

class EpochRunner:
    def __init__(self,steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage

    def __call__(self,dataloader):
        total_loss,step = 0,0
        loop = tqdm(enumerate(dataloader), total =len(dataloader))
        for i, batch in loop: 
            loss, step_metrics = self.steprunner(*batch)
            step_log = dict({self.stage+"_loss":loss},**step_metrics)
            total_loss += loss
            step+=1
            if i!=len(dataloader)-1:
                loop.set_postfix(**step_log)
            else:
                epoch_loss = total_loss/step
                epoch_metrics = {self.stage+"_"+name:metric_fn.compute().item() 
                                 for name,metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage+"_loss":epoch_loss},**epoch_metrics)
                loop.set_postfix(**epoch_log)

                for name,metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log


def train_model(net, optimizer, loss_fn, metrics_dict, 
                train_data, val_data=None, 
                epochs=10, ckpt_path='checkpoint.pt',
                patience=5, monitor="val_loss", mode="min"):

    history = {}

    for epoch in range(1, epochs+1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))

        # 1,train -------------------------------------------------  
        train_step_runner = StepRunner(net = net,stage="train",
                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict),
                optimizer = optimizer)
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_data)

        for name, metric in train_metrics.items():
            history[name] = history.get(name, []) + [metric]

        # 2,validate -------------------------------------------------
        if val_data:
            val_step_runner = StepRunner(net = net,stage="val",
                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict))
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_data)
            val_metrics["epoch"] = epoch
            for name, metric in val_metrics.items():
                history[name] = history.get(name, []) + [metric]

        # 3,early-stopping -------------------------------------------------
        arr_scores = history[monitor]
        best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
        if best_score_idx==len(arr_scores)-1:
            torch.save(net.state_dict(),ckpt_path)
            print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
                 arr_scores[best_score_idx]),file=sys.stderr)
        if len(arr_scores)-best_score_idx>patience:
            print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                monitor,patience),file=sys.stderr)
            break 
        net.load_state_dict(torch.load(ckpt_path))

    return pd.DataFrame(history)

使用方法如下:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from torchmetrics import Accuracy

loss_fn = nn.BCEWithLogitsLoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   
metrics_dict = {"acc":Accuracy()}

dfhistory = train_model(net,
    optimizer,
    loss_fn,
    metrics_dict,
    train_data = dl_train,
    val_data= dl_val,
    epochs=10,
    patience=5,
    monitor="val_acc", 
    mode="max")

疗效如下:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
================================================================================2022-07-10 20:06:16
Epoch 1 / 10

100%|██████████| 200/200 [00:17<00:00, 11.74it/s, train_acc=0.735, train_loss=0.53]
100%|██████████| 40/40 [00:01<00:00, 20.07it/s, val_acc=0.827, val_loss=0.383]
<<<<<< reach best val_acc : 0.8274999856948853 >>>>>>

================================================================================2022-07-10 20:06:35
Epoch 2 / 10

100%|██████████| 200/200 [00:16<00:00, 11.96it/s, train_acc=0.832, train_loss=0.391]
100%|██████████| 40/40 [00:02<00:00, 18.13it/s, val_acc=0.854, val_loss=0.317]
<<<<<< reach best val_acc : 0.8544999957084656 >>>>>>

================================================================================2022-07-10 20:06:54
Epoch 3 / 10

100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.87, train_loss=0.313]
100%|██████████| 40/40 [00:02<00:00, 19.96it/s, val_acc=0.902, val_loss=0.239]
<<<<<< reach best val_acc : 0.9024999737739563 >>>>>>

================================================================================2022-07-10 20:07:13
Epoch 4 / 10

100%|██████████| 200/200 [00:16<00:00, 11.88it/s, train_acc=0.889, train_loss=0.265]
100%|██████████| 40/40 [00:02<00:00, 18.46it/s, val_acc=0.91, val_loss=0.216]
<<<<<< reach best val_acc : 0.9100000262260437 >>>>>>

================================================================================2022-07-10 20:07:32
Epoch 5 / 10

100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.902, train_loss=0.239]
100%|██████████| 40/40 [00:02<00:00, 19.68it/s, val_acc=0.891, val_loss=0.279]

================================================================================2022-07-10 20:07:51
Epoch 6 / 10

100%|██████████| 200/200 [00:17<00:00, 11.75it/s, train_acc=0.915, train_loss=0.212]
100%|██████████| 40/40 [00:02<00:00, 19.52it/s, val_acc=0.908, val_loss=0.222]

================================================================================2022-07-10 20:08:10
Epoch 7 / 10

100%|██████████| 200/200 [00:16<00:00, 11.79it/s, train_acc=0.921, train_loss=0.196]
100%|██████████| 40/40 [00:02<00:00, 19.26it/s, val_acc=0.929, val_loss=0.187]
<<<<<< reach best val_acc : 0.9294999837875366 >>>>>>

================================================================================2022-07-10 20:08:29
Epoch 8 / 10

100%|██████████| 200/200 [00:17<00:00, 11.59it/s, train_acc=0.931, train_loss=0.175]
100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.938, val_loss=0.187]
<<<<<< reach best val_acc : 0.9375 >>>>>>

================================================================================2022-07-10 20:08:49
Epoch 9 / 10

100%|██████████| 200/200 [00:17<00:00, 11.68it/s, train_acc=0.929, train_loss=0.178]
100%|██████████| 40/40 [00:02<00:00, 19.90it/s, val_acc=0.937, val_loss=0.181]

================================================================================2022-07-10 20:09:08
Epoch 10 / 10

100%|██████████| 200/200 [00:16<00:00, 11.84it/s, train_acc=0.937, train_loss=0.16] 
100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.937, val_loss=0.167]

该训练循环满足我所说的以上全部这些特性。

  • 1,模块化:自下而上分成 StepRunner, EpochRunner, 和train_model 三级,结构清晰明了。
  • 2,易修改:如果输入和label形式有差异(例如,输入可能组装成字典,或者有多个输入),仅需更改StepRunner就可以了,后面无需改动,非常灵活。
  • 3,short-enough: 全部训练代码不到150行。
  • 4,支持进度条:通过tqdm引入。
  • 5,支持评估指标:引入torchmetrics库中的指标。
  • 6,支持early-stopping:在train_model函数中指定 monitor、mode、patience即可。

以上训练循环也是我在eat_pytorch_in_20_days中使用的主要训练循环。该库目前已经获得3.3k+星星⭐️,大部分读者反馈还是挺好用的。

点击文末阅读原文,查看知乎原始回答,感觉不错的小伙伴可以给吃货本货一个赞同表示鼓励哦,谢谢大家。😊 逃~

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-08-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法美食屋 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
深度学习里面有没有支持Multi-GPU-DDP模式的pytorch模型训练代码模版?
一般pytorch需要用户自定义训练循环,可以说有1000个pytorch用户就有1000种训练代码风格。
lyhue1991
2023/02/23
7060
深度学习里面有没有支持Multi-GPU-DDP模式的pytorch模型训练代码模版?
60分钟吃掉FM算法
参考文章:张俊林《FFM及DeepFFM模型在推荐系统的探索》https://zhuanlan.zhihu.com/p/67795161
lyhue1991
2022/09/01
6600
60分钟吃掉FM算法
Kaggle免费GPU使用攻略
在国内使用邮箱注册kaggle时会遇到一个人机验证的步骤,可以通过翻墙访问外网的方式完成,但比较麻烦。
lyhue1991
2023/02/23
5K0
Kaggle免费GPU使用攻略
60分钟吃掉三杀模型FiBiNET
新浪微博广告推荐技术团队2019年发布的CTR预估模型FiBiNET同时巧妙地运用了以上3种技巧,是神经网络结构设计的教科书级的范例。
lyhue1991
2022/09/01
1.4K0
60分钟吃掉三杀模型FiBiNET
炼丹5至7倍速,使用Mac M1 芯片加速pytorch完全指南
2022年5月,PyTorch官方宣布已正式支持在M1芯片版本的Mac上进行模型加速。官方对比数据显示,和CPU相比,M1上炼丹速度平均可加速7倍。
lyhue1991
2023/02/23
16.3K2
炼丹5至7倍速,使用Mac M1 芯片加速pytorch完全指南
YOLOv8 训练自己的数据集
本范例我们使用 ultralytics中的YOLOv8目标检测模型训练自己的数据集,从而能够检测气球。
lyhue1991
2023/09/05
3.4K1
YOLOv8 训练自己的数据集
120分钟吃掉DIEN深度兴趣演化网络
2018年的深度兴趣演化网络, DIEN(DeepInterestEvolutionNetWork)。
lyhue1991
2023/02/23
4810
120分钟吃掉DIEN深度兴趣演化网络
使用BERT进行文本分类
准备数据阶段主要需要用到的是datasets.Dataset 和transformers.AutoTokenizer。
lyhue1991
2023/09/05
7330
使用BERT进行文本分类
30分钟吃掉CRNN-CTC验证码识别
项目参考:https://github.com/ypwhs/captcha_break
lyhue1991
2023/09/05
3510
30分钟吃掉CRNN-CTC验证码识别
AlexNet代码详解
AlexNet由Hinton和他的学生Alex Krizhevsky设计,模型名字来源于论文第一作者的姓名Alex。该模型以很大的优势获得了2012年ISLVRC竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+,自那年之后,深度学习开始迅速发展。
OliverHan
2023/04/23
8540
30分钟吃掉DQN算法
表格型方法存储的状态数量有限,当面对围棋或机器人控制这类有数不清的状态的环境时,表格型方法在存储和查找效率上都受局限,DQN的提出解决了这一局限,使用神经网络来近似替代Q表格。
lyhue1991
2023/09/05
2910
30分钟吃掉DQN算法
Q-learning解决悬崖问题
Q-learning是一个经典的强化学习算法,是一种基于价值(Value-based)的算法,通过维护和更新一个价值表格(Q表格)进行学习和预测。
lyhue1991
2023/09/05
3810
Q-learning解决悬崖问题
第一个深度学习实战案例:电影评论分类
这是一个典型的二分类问题。使用的是IMDB数据集,训练集是25000条,测试也是25000条
皮大大
2022/04/02
5220
第一个深度学习实战案例:电影评论分类
30分钟吃掉YOLOv8实例分割范例
本范例我们使用 torchkeras来实现对 ultralytics中的YOLOv8实例分割模型进行自定义的训练,从而对气球进行检测和分割。
lyhue1991
2023/09/17
2.6K1
30分钟吃掉YOLOv8实例分割范例
用BERT做命名实体识别任务
本质上NER是一个token classification任务, 需要把文本中的每一个token做一个分类。
lyhue1991
2023/09/05
7620
用BERT做命名实体识别任务
使用 PyTorch Geometric 在 Cora 数据集上训练图卷积网络GCN
图结构在现实世界中随处可见。道路、社交网络、分子结构都可以使用图来表示。图是我们拥有的最重要的数据结构之一。
deephub
2021/12/28
2.1K0
使用 PyTorch Geometric 在 Cora 数据集上训练图卷积网络GCN
深度学习框架Keras深入理解
Python深度学习-深入理解Keras:Keras标准工作流程、回调函数使用、自定义训练循环和评估循环。
皮大大
2023/09/06
4920
60分钟吃掉ChatGLM2-6b微调范例~
干货预警:这可能是你能够找到的最容易懂的,最完整的,适用于各种NLP任务的开源LLM的finetune教程~
lyhue1991
2023/09/05
8090
60分钟吃掉ChatGLM2-6b微调范例~
卷积网络与全连接网络比较分析
我们通过对比全连接和卷积的学习过程最后的精确度等因素,发现卷积比全连接神经网络更适合做图像处理,在这个过程中,全连接模型中会有很多参数,这对于图像的要求太高,如果图像出现变动,会导致模型改动较大。
算法与编程之美
2023/08/22
2020
卷积网络与全连接网络比较分析
Tensorflow搭建CNN实现验证码识别
采用三层卷积,filter_size均为5,为避免过拟合,每层卷积后面均接dropout操作,最终将
Awesome_Tang
2019/01/28
8610
Tensorflow搭建CNN实现验证码识别
相关推荐
深度学习里面有没有支持Multi-GPU-DDP模式的pytorch模型训练代码模版?
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验