Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >干货 | 如何理解深度学习分布式训练中的large batch size与learning rate的关系?

干货 | 如何理解深度学习分布式训练中的large batch size与learning rate的关系?

作者头像
AI科技评论
发布于 2018-03-14 06:42:57
发布于 2018-03-14 06:42:57
2.9K0
举报
文章被收录于专栏:AI科技评论AI科技评论

问题详情:

深度学习进行分布式训练时,常常采用同步数据并行的方式,也就是采用大的batch size进行训练,但large batch一般较于小的baseline的batch size性能更差,请问如何理解调试learning rate能使large batch达到small batch同样的收敛精度和速度?

回答:

最近在进行多GPU分布式训练时,也遇到了large batch与learning rate的理解调试问题,相比baseline的batch size,多机同步并行(之前有答案是介绍同步并行的通信框架NCCL 谭旭:如何理解Nvidia英伟达的Multi-GPU多卡通信框架NCCL?)等价于增大batch size,如果不进行精细的设计,large batch往往收敛效果会差于baseline的小batch size。因此将自己的理解以及实验总结如下,主要分为三个方面来介绍:(1)理解SGD、minibatch-SGD和GD,(2)large batch与learning rate的调试关系,(3)我们的实验。

(1)理解SGD、minibatch-SGD和GD

机器学习优化算法中,GD(gradient descent)是最常用的方法之一,简单来说就是在整个训练集中计算当前的梯度,选定一个步长进行更新。GD的优点是,基于整个数据集得到的梯度,梯度估计相对较准,更新过程更准确。但也有几个缺点,一个是当训练集较大时,GD的梯度计算较为耗时,二是现代深度学习网络的loss function往往是非凸的,基于凸优化理论的优化算法只能收敛到local minima,因此使用GD训练深度神经网络,最终收敛点很容易落在初始点附近的一个local minima,不太容易达到较好的收敛性能。

另一个极端是SGD(stochastic gradient descent),每次计算梯度只用一个样本,这样做的好处是计算快,而且很适合online-learning数据流式到达的场景,但缺点是单个sample产生的梯度估计往往很不准,所以得采用很小的learning rate,而且由于现代的计算框架CPU/GPU的多线程工作,单个sample往往很难占满CPU/GPU的使用率,导致计算资源浪费。

折中的方案就是mini-batch,一次采用batch size的sample来估计梯度,这样梯度估计相对于SGD更准,同时batch size能占满CPU/GPU的计算资源,又不像GD那样计算整个训练集。同时也由于mini batch能有适当的梯度噪声[8],一定程度上缓解GD直接掉进了初始点附近的local minima导致收敛不好的缺点,所以mini-batch的方法也最为常用。

关于增大batch size对于梯度估计准确度的影响,分析如下:

假设batch size为m,对于一个minibatch,loss为:

梯度

整个minibatch的梯度方差为:

由于每个样本

是随机从训练样本集sample得到的,满足i.i.d.假设,因此样本梯度的方差相等,为

等价于SGD的梯度方差,可以看到batch size增大m倍,相当于将梯度的方差减少m倍,因此梯度更加准确。

如果要保持方差和原来SGD一样,相当于给定了这么大的方差带宽容量,那么就可以增大lr,充分利用这个方差容量,在上式中添加lr,同时利用方差的变化公式,得到等式

因此可将lr增加sqrt(m)倍,以提高训练速度,这也是在linear scaling rule之前很多人常用的增大lr的方式[4]。下一小节将详细介绍增大lr的问题。

(2)large batch与learning rate

在分布式训练中,batch size 随着数据并行的worker增加而增大,假设baseline的batch size为B,learning rate为lr,训练epoch数为N。如果保持baseline的learning rate,一般不会有较好的收敛速度和精度。原因如下:对于收敛速度,假设k个worker,每次过的sample数量为kB,因此一个epoch下的更新次数为baseline的1/k,而每次更新的lr不变,所以要达到baseline相同的更新次数,则需要增加epoch数量,最大需要增加k*N个epoch,因此收敛加速倍数会远远低于k。对于收敛精度,由于增大了batch size使梯度估计相较于badeline的梯度更加准确,噪音减少,更容易收敛到附近的local minima,类似于GD的效果。

为了解决这个问题,一个方法就是增大lr,因为batch变大梯度估计更准,理应比baseline的梯度更确信一些,所以增大lr,利用更准确的梯度多走一点,提高收敛速度。同时增大lr,让每次走的幅度尽量大一些,如果遇到了sharp local minima[8](sharp minima的说法现在还有争议,暂且引用这个说法),还有可能逃出收敛到更好的地方。

但是lr不能无限制的增大,原因分析如下。深度神经网络的loss surface往往是高维高度非线性的,可以理解为loss surface表面凹凸不平,坑坑洼洼,不像y=x^2曲线这样光滑,因此基于当前weight计算出来的梯度,往前更新的learing rate很大的时候,沿着loss surface的切线就走了很大一步,有可能大大偏于原有的loss surface,示例如下图(a)所示,虚线是当前梯度的方向,也就是当前loss surface的切线方向,如果learning rate过大,那这一步沿切线方向就走了很大一步,如果一直持续这样,那很可能就走向了一个错误的loss surface,如图(b)所示。如果是较小的learning rate,每次只沿切线方向走一小步,虽然有些偏差,依然能大致沿着loss sourface steepest descent曲线向下降,最终收敛到一个不错的local minima,如图(c)所示。

同时也可以根据convex convergence theory[2]得到lr的upper bound:lr<1/L,L为loss surface的gradient curve的Lipschitz factor,L可以理解为loss梯度的变化幅度的上界。如果变化幅度越大,L越大,则lr就会越小,如果变化幅度越小,L越小,则lr就可以很大。这和上图的分析是一致的。

因此,如何确定large batch与learing rate的关系呢?

分别比较baseline和k个worker的large batch的更新公式[7],如下:

这个是baseline(batch size B)和large batch(batch size kB)的更新公式,(4)中large batch过一步的数据量相当于(3)中baseline k步过的数据量,loss和梯度都按找过的数据量取平均,因此,为了保证相同的数据量利用率,(4)中的learning rate应该为baseline的k倍,也就是learning rate的linear scale rule。

linear scale rule有几个约束,其中一个约束是关于weight的约束,式(3)中每一步更新基于的weight都是前一步更新过后的weight,因此相当于小碎步的走,每走一部都是基于目前真实的weight计算梯度做更新的,而式(4)的这一大步(相比baseline相当于k步)是基于t时刻的weight来做更新的。如果在这k步之内,W(t+j) ~ W(t)的话,两者近似没有太大问题,也就是linear scale rule问题不大,但在weight变化较快的时候,会有问题,尤其是模型在刚开始训练的时候,loss下特别快,weight变化很快,W(t+j) ~ W(t)就不满足。因此在初始训练阶段,一般不会直接将lr增大为k倍,而是从baseline的lr慢慢warmup到k倍,让linear scale rule不至于违背得那么明显,这也是facebook一小时训练imagenet的做法[7]。第二个约束是lr不能无限的放大,根据上面的分析,lr太大直接沿loss切线跑得太远,导致收敛出现问题。

同时,有文献[5]指出,当batchsize变大后,得到好的测试结果所能允许的lr范围在变小,也就是说,当batchsize很小时,比较容易找打一个合适的lr达到不错的结果,当batchsize变大后,可能需要精细地找一个合适的lr才能达到较好的结果,这也给实际的large batch分布式训练带来了困难。

(3)我们的实验

最近在考虑分布式训练NLP相关的深度模型的问题,实验细节如下,由于某些工作暂时还不方便透露,只提供较为简略的实验细节:

模型baseline参数为batch size 32, lr 0.25,最终的accuracy为BLEU score: 28.35。现在进行分布式扩展到多卡并行。

实验1:只增加并行worker数(也就相当于增大batch size),lr为baseline的lr0保持不变

可以看到随着batch的变大, 如果lr不变,模型的精度会逐渐下降,这也和上面的分析相符合。

实验2:增大batch size,lr相应增大

可以看到通过增加lr到5*lr0(理论上lr应该增加到8倍,但实际效果不好,因此只增加了5倍),并且通过warmup lr,达到和baseline差不多的Bleu效果。最终的收敛速度大约为5倍左右,也就是8卡能达到5倍的收敛加速(不考虑系统通信同步所消耗的时间,也就是不考虑系统加速比的情况下)。

深度学习并行训练能很好的提升模型训练速度,但是实际使用的过程中会面临一系列的问题,包括系统层面的架构设计、算法层面的参数调试等,欢迎有兴趣的朋友多多探讨。

[1] Li M, Zhang T, Chen Y, et al. Efficient mini-batch training for stochastic optimization[C]// Acm Sigkdd International Conference on Knowledge Discovery & Data Mining. ACM, 2014:661-670.

[2] Bottou L, Curtis F E, Nocedal J. Optimization Methods for Large-Scale Machine Learning[J]. 2016.

[3] Dekel O, Gilad-Bachrach R, Shamir O, et al. Optimal distributed online prediction using mini-batches[J]. Journal of Machine Learning Research, 2012, 13(1):165-202.

[4] Krizhevsky A. One weird trick for parallelizing convolutional neural networks[J]. Eprint Arxiv, 2014.

[5] Breuel T M. The Effects of Hyperparameters on SGD Training of Neural Networks[C]., 2015.

[6] Mishkin D, Sergievskiy N, Matas J. Systematic evaluation of CNN advances on the ImageNet[J]. 2016.

[7] Goyal, Priya, et al. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour." arXiv preprint arXiv:1706.02677 (2017).

[8] Keskar N S, Mudigere D, Nocedal J, et al. On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima[J]. 2016.

[9]Scaling Distributed Machine Learning with System and Algorithm Co-design - Google Search

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2017-11-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI科技评论 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
解惑Java注解类型(待更新)理解Java注解基本语法注解与反射机制运行时注解处理器Java 8中注解增强
java注解是在JDK5时引入的新特性,鉴于目前大部分框架(如Spring)都使用了注解简化代码并提高编码的效率,因此掌握并深入理解注解对于一个Java工程师是来说是很有必要的事。本篇我们将通过以下几个角度来分析注解的相关知识点
JavaEdge
2018/10/11
1.9K0
Java 注解入门 自动生成SQL语句
在用hibernate的时候发现idea能自动生成JavaBean,同时带有一些注解,这引起了我的好奇。当在学习Android的时候,我发现XUtils这个工具包中的DBUtils也能够使用类似hibernate的注解。于是乎在java编程思想中找了找有关注解的用法。
zhangheng
2020/04/29
1.4K0
Java 注解 学习笔记
我们平常写Java代码,对其中的注解并不是很陌生,比如说写继承关系的时候经常用到@Override来修饰方法。但是@Override是用来做什么的,为什么写继承方法的时候要加上它,不加行不行。如果对Java的注解没有了解过,很难回答这些问题。并且,现在越来越多的第三方库开始使用注解,不了解注解的话很难理解他们的逻辑。趁着五一假期,赶紧补习一下什么是注解。
yuxiaofei93
2018/09/11
5840
java注意事项演示 地图产生表 演示样本 来自thinking in java 4 20代码的章
java注意事项演示 地图产生表 演示样本 来自thinking in java 4 20代码的章
全栈程序员站长
2022/07/06
3310
Java 注解机制
注解是 JDK1.5版本开始引入的一个特性,用于对代码进行说明,可以对包、类、接口、字段、方法参数、局部变量等进行注解。它主要的作用有以下四方面: 【1】生成文档:通过代码里标识的元数据生成 javadoc文档。 【2】编译检查:通过代码里标识的元数据让编译器在编译期间进行检查验证。 【3】编译时动态处理:编译时通过代码里标识的元数据动态处理,例如动态生成代码。 【4】运行时动态处理:运行时通过代码里标识的元数据动态处理,例如使用反射注入实例。
Java架构师必看
2021/05/14
6470
【Java 基础 - 注解机制详细解释】
注解是JDK1.5版本开始引入的一个特性,用于对代码进行说明,可以对包、类、接口、字段、方法参数、局部变量等进行注解。它主要的作用有以下四方面:
奥耶可乐冰
2024/05/31
1590
java自定义注解的使用和基本原理「建议收藏」
在web开发中,权限控制非常重要,所以有些接口会限制必须登录之后才能访问,但是个别接口并没有这种限制。一种方式是把需要过滤的接口或者方法配置在文件中,每次请求时在拦截器中根据请求的路径与配置文件中的对比过滤。其实还有另外一种方式就是通过注解方式。
全栈程序员站长
2022/07/30
5180
java自定义注解的使用和基本原理「建议收藏」
Java 注解 —— 注解的理解、注解的使用与自定义注解
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/ajianyingxiaoqinghan/article/details/81436118
剑影啸清寒
2019/05/26
3.3K0
java 注解简述
注解(annotation)相当于一个运行于内存当中的自定义类型的数据存储区域,理解以后才发现它的好用,就是数据存储区,相当于一个运行在内存当中的XML,所有的注解数据在JDK加载完类以后,就可以被使用。
潇洒
2023/10/20
2070
Java-Java5.0注解解读
Java5.0注解可以看做Javadoc和Xdoclet标签的延伸和发展,在Java5.0中可以自定义这些标签,并通过Java语言的反射机制获取类中标注的注解,完成特定的功能。
小小工匠
2021/08/16
2610
Java 注解 Annotation 详解
注解(Annotation)就是 Java 提供了一种元程序中的元素关联任何信息和着任何元数据(metadata)的途径和方法。Annotation 是一个接口,程序可以通过反射来获取指定程序元素的 Annotation 对象,然后通过 Annotation 对象来获取注解里面的元数据。
BUG弄潮儿
2021/09/10
1.4K0
Java 注解 Annotation 详解
基础篇:深入解析JAVA注解机制
在代码里定义的注解,会被jvm利用反射技术生成一个代理类,然后和被注释的代码(类,方法,属性等)关联起来
潜行前行
2020/12/11
6700
基础篇:深入解析JAVA注解机制
Java编译时注解自动生成代码[通俗易懂]
在开始之前,我们首先申明一个非常重要的问题:我们并不讨论那些在运行时(Runtime)通过反射机制运行处理的注解,而是讨论在编译时(Compile time)处理的注解。注解处理器是一个在javac中的,用来编译时扫描和处理的注解的工具。可以为特定的注解,注册自己的注解处理器。
全栈程序员站长
2022/09/01
2.9K0
spring自定义注解实现(spring里面的注解)
1.SOURCE:在源文件中生效,仅存在java文件中,class文件将会去除注解。
全栈程序员站长
2022/07/30
8220
spring自定义注解实现(spring里面的注解)
Java注解Annotation使用
Java注解Annotation的使用 RuntimeAnnotation注解 package annotation; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; @Retention(RetentionPolicy.RUNT
用户9854323
2022/06/25
2000
TIII-Android技术篇之注解Annotation
开幕:初见 首先看一下家喻户晓的@Override注解:添加此注解,如果是非覆写的方法,就会报错 @Target(ElementType.METHOD) @Retention(RetentionPolicy.SOURCE) public @interface Override { } 再先看一下@Deprecated注解:添加此注解,如果是过时的方法,就会画线提示 @Documented @Retention(RetentionPolicy.RUNTIME) @Target(value={CONSTRUCT
张风捷特烈
2018/09/29
4710
TIII-Android技术篇之注解Annotation
Retrofit解析4之注解
由于Retrofit里面大量的用到了注解,为了让大家更好的学习Retrofit,特意准备了一篇Java注解,如果大家已经对Java注解已经很熟悉了,就略过,看下一篇文章 本篇文章主要讲解
隔壁老李头
2018/08/30
1.4K0
Retrofit解析4之注解
Java 注解与单元测试
Java注解是在JDK1.5 之后出现的新特性,用来说明程序的,注解的主要作用体现在以下几个方面:
Masimaro
2019/09/02
1.2K0
Java中的注解处理器是什么,提供一个自定义注解处理器的实际案例
Java中的注解处理器(Annotation Processor)是一种在编译时期处理注解的工具,它可以通过扫描和解析源代码中的注解信息,生成额外的代码、配置文件或者进行其他特定的处理操作。注解处理器能够帮助开发者实现自定义的代码生成、静态分析、验证等功能,从而提高开发效率和代码质量。
用户1289394
2024/06/11
2450
Java中的注解处理器是什么,提供一个自定义注解处理器的实际案例
你分析过注解 Annotation 的实现原理吗?
对于很多初次接触的开发者来说应该都有这个疑问?Annontation是Java5开始引入的新特征,中文名称叫注解。它提供了一种安全的类似注释的机制,用来将任何的信息或元数据(metadata)与程序元素(类、方法、成员变量等)进行关联。为程序的元素(类、方法、成员变量)加上更直观更明了的说明,这些说明信息是与程序的业务逻辑无关,并且供指定的工具或框架使用。
JavaFish
2019/10/17
7K0
相关推荐
解惑Java注解类型(待更新)理解Java注解基本语法注解与反射机制运行时注解处理器Java 8中注解增强
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档