前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【Pytorch】model.train() 和 model.eval() 原理与用法

【Pytorch】model.train() 和 model.eval() 原理与用法

作者头像
自学气象人
发布2023-09-06 14:23:49
9570
发布2023-09-06 14:23:49
举报
文章被收录于专栏:自学气象人自学气象人

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train( ) 和 model.eval( )。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train( ),作用是启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在训练时添加 model.train( )。

model.train( ) 是保证 BN 层能够用到每一批数据的均值和方差。对于 Dropout,model.train( ) 是随机取一部分网络连接来训练更新参数。

2. model.eval()

model.eval( )的作用是不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在测试时添加 model.eval( )。

model.eval( ) 是保证 BN 层能够用全部训练数据的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval( ) 是利用到了所有网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval( ),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval( ) 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。

不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。eval( ) 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。也就是说,测试过程中使用model.eval( ),这时神经网络会沿用 batch normalization 的值,而不使用dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval( )。

其中 model.train( ) 是保证 BN 层用每一批数据的均值和方差,而 model.eval( ) 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train( ) 是随机取一部分网络连接来训练更新参数,而 model.eval( ) 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

参考链接

[1]

PyTorch中train()方法的作用是什么: https://www.yisu.com/zixun/518049.html

[2]

【pytorch】model.train()和model.evel()的用法: https://blog.csdn.net/qq_37791134/article/details/108122202

[3]

pytorch中net.eval() 和net.train()的使用: https://www.jianshu.com/p/822d9ae0169d

[4]

Pytorch学习笔记11----model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析、输出整个tensor的方法: https://www.cnblogs.com/luckyplj/p/13424561.html

[5]

好文:Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别: https://zhuanlan.zhihu.com/p/357075502

来源:代码网

作者:qgyh

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-05-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 自学气象人 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、两种模式
  • 二、功能
    • 1. model.train()
      • 2. model.eval()
        • 为什么测试时要用 model.eval() ?
          • 3. 总结与对比
          • 三、Dropout 简介
          • 参考链接
          相关产品与服务
          腾讯云服务器利旧
          云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档