Meta Learning的应用MAML
Meta Learnig,元学习,就是能够让机器学习如何去学习(Learning to Learning)。MAML即模型无关,绝大多数深度学习模型都可以作为base-learner无缝嵌入,甚至可以用于强化学习。MAML是学习一个最好的初始化的方法,训练出合适的模型初始化参数,使得在样本量很少的情况下快速收敛。该网络的目标是训练一个模型, 如果给定一个新任务的一步梯度更新, 那么它便可以很好地在该任务泛化。
论文:《MAML: ModelAgnostic Meta Learning》 论文地址:https://arxiv.org/pdf/1703.03400.pdf github:httpshttps://github.com/dragen1860/MAML-Pytorch
将数据集分成meta-train和meta-test两部分,meta-test测试模型的收敛速度(用Dtrain训练,用Dtest测试分类效果),meta-train用于训练模型(Dtrain和Dtest一起训练模型),下图中每一个横条称为一个task。

对于meta-train数据集,首先整理成上图中的形式(从100个类中随机抽取5个类,从5个类的数据集中分别随机抽1个样例作为Dtrain,从5个类数据集中随机抽3个样例作为Dtest),这样一来就会形成很多个task,这些task就是训练集(一个task相当于传统机器学习中的一个样例),多个task构成一个batch。算法如下:

首先初始化模型参数,之后开始对训练集和验证集的task样本随机采样,每个task是一组样本,包含N个类,每类包括K个训练样本和K’个测试样本。针对每个task的meta training set 计算梯度gradient descent,得到

为该task的期望参数

,但不反向传播去更新真实参数Theta。用meta testing set在期望参数来验证,得到一个meta loss

,遍历所有task之后得到

作为最后的meta loss。
MAML希望通过梯度更新,找到对task敏感的theta,使得模型具有对新task的学习分布最敏感的参数,在一次或多次梯度更新中获得适应新任务的theta*。

MAML采取的学习方式是优化参数在各个任务上的梯度方向矢量和,具体原理则是元任务adapt一步(或几步)提供的二阶导信息,也就是curvature of tasks,这样的高阶导信息可以为模型的初始化提供方向信息,也就是我们所用的每个task的梯度方向。优化分为两层: inner loop和outer loop。Inner loop就是training procedure,对于每个任务学习处理这个任务的基本能力。outer loop就是meta training procedure,学习多个任务的泛化能力。

注意这里的Meta Learning和Pretrain model是不一样的,MAML是在所有任务上训练得到新的模型参数,Pretrain model是采用一批数据训练出一个模型,这样的模型参数直接用在下一批新的数据中,


MAML最终的参数在task1和task2都不是最好的表现,但对于两个任务来说是一个相对比较好的结果,但可以在task1中经过训练得到theta1,在task2中经过训练得到theta2。

在Pretrain model中,比较看重的是当前的表现,参数fei在task1中表现不错,经过训练后得到theta在task2中表现不错,而MAML更看重未来的表现。

在MAML中,我们希望每个任务的参数更新只需要一次就可以得到很好的参数表现,这样对于多任务的训练速度很快。
参考: https://www.jianshu.com/p/692604df9cfb https://zhuanlan.zhihu.com/p/68555964?app=zhihulite https://www.zhihu.com/question/266497742/answer/550695031

本文分享自 Python编程和深度学习 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!