Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >PyTorch 零基础入门 GAN 模型之 cGAN

PyTorch 零基础入门 GAN 模型之 cGAN

作者头像
OpenMMLab 官方账号
修改于 2022-05-12 08:37:01
修改于 2022-05-12 08:37:01
1.4K00
代码可运行
举报
文章被收录于专栏:OpenMMLabOpenMMLab
运行总次数:0
代码可运行

背景介绍

在之前的文章中,我们介绍了 GAN 的原理以及如何评价训练好的模型。可能有小伙伴看到,怎么生成的都是单一类别的图片呢,像 CIFAR10 和 ImageNet,都包含了多种类别的图片,如果我想训练一个能够生成多种类别图片的生成对抗网络该怎么做呢?

那么称为 conditional GANs (cGAN) 的模型就可以派上用场了。

cGAN 的发展

如何引入类别信息

最早提出 cGAN 的是论文 《Conditional Generative Adversarial Nets》,为了达到条件生成的目的,我们在输入给生成器网络 G 的噪声 z 上 concat 一个标签向量 y, 告诉生成网络生成标签所指定的数据。对于输入给判别器 D 的数据,也 concat 这样的一个标签,告诉判别网络判断输入是否为真实的该类别数据。

那么,cGAN 的目标函数可以表述成如下形式:

cGAN 采用 MLP 作为网络结构,一维的输入可以方便地和标签向量或者标签嵌入 concat,但是对于图像生成任务主流的 CNN 模型,无法直接采用这种引入方式,特别是对于判别器网络。

Projection GAN 通过推导发现,假设 ,其中 为判别器网络, 为激活函数, 为网络模型,在上面的目标函数下,最优的 为

根据上式,输入图像先经过网络 抽取一维特征,然后分两路,一路通过网络 输出关于图像真假的判别结果,一路与经过编码的类别标签做点乘得到关于类别的判别结果,之后将两路结果相加就可以得到最终的判别结果。模型结构如下:

如何稳定训练过程

GAN 的训练过程不稳定,而 cGAN 中学习多种类别的数据,稳定训练过程则更具有挑战性。在之前的研究如 WGAN,WGANGP 中,学者们发现对网络施加 Lipschitz 约束能有效稳定 GAN 的训练过程。

相比于之前在损失函数中增加正则项的做法,SNGAN 提出了谱归一化 (spectral normalization) 来构造网络模型,使得无论网络参数是什么,都能满足 Lipschitz 约束。

如何提升生成质量

在多类别数据如 ImageNet 的训练过程中,人们发现网络更加擅长生成局部的细节纹理,如狗的毛发;而对于几何特征和整体结构,生成效果往往不尽如人意。如下图 SNGAN 生成结果为例,狗的身体结构存在很多错误和不完整的表达。

SAGAN 认为这是由于卷积模型难以捕捉到距离较远的特征,因此引入注意力机制,设计了如下的注意力模块。

这里卷积特征图经过三个 1x1 卷积 f(x), g(x), h(x),将 f(x) 的输出转置,并和 g(x) 的输出相乘,再经过 softmax 归一化得到一个 attention map ,将得到的 attention map 和 h(x) 逐像素点相乘,再经过卷积 v(x) 得到自适应注意力的特征图。

SAGAN 在生成器和判别器中引入了这个注意力模块,使得生成器可以建模图像跨区域的依赖关系,判别器可以对全局图像结构施加几何约束。

集大成者:BigGAN

在采用上面提到的投影判别,谱归一化,自注意模块的基础上,《LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS》(即 BigGAN )通过增大 batch size,提升模型宽度,显著提升了图像生成的结果。(见下图,我们在上篇文章中提到了 FID 和 IS 两个评价生成模型的 metrics,可以看到,通过提升 batch size 和 通道数,这两个 metric 有显著提升。)

为了让 noise 能够直接影响不同分辨率上的特征,BigGAN 提出了 skip-z ,将 noise 分割后,分别传入 G 网络不同尺度的 layer。为了节省时间和内存,BigGAN 只对 class 做一次 embedding,将编码结果传给每个条件批归一化。以下为一个 BigGAN 的 G 网络结构图。

其中每一个 ResBlock 结构如下。分组的噪声和类别嵌入 concat 后,通过线性层得到每个 BatchNorm 的 gain 和 bias 参数,通过这种方式引入类别信息。

使用 MMGeneration 上手 BigGAN

在我们的文章 PyTorch 零基础入门 GAN 模型之基础篇 中,我们介绍了如何安装 MMGen 和训练模型。在此基础上,我们可以上手以 BigGAN 为代表的条件生成模型。

我们可以先看看 BigGAN 生成的图片长啥样,通过运行如下代码,我们可以从预训练好的 BigGAN 中 sample 类别随机的图片。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth

当然,我们也可以用 --label 来指定采样的类别,--samples-per-classes 来指定每类采样的数量。

比如运行下面代码:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
python demo/conditional_demo.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --label 151 285 292 --samples-per-classes 5

可以得到 狗(151),猫(285),老虎(292)各五张图片。

我们还可以通过运行下面的代码看看 BigGAN 分别从噪声空间和标签空间插值的结果。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-z # 固定噪声

python apps/conditional_interpolate.py configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py https://download.openmmlab.com/mmgen/biggan/biggan_imagenet1k_128x128_b32x8_best_fid_iter_1232000_20211111_122548-5315b13d.pth --samples-path work_dirs/demos/ --show-mode group --fix-y # 固定标签

结果分别如下:

现在我们可以看看怎么训练 BigGAN,首先我们需要下载 ImageNet,然后放到 ./data 文件夹下。ImageNet 的下载方法可以参考这篇知乎分享 https://zhuanlan.zhihu.com/p/42696535。

BigGAN 的训练有几个关键点设置,我们以 configs/_base_/models/biggan/biggan_128x128.py 和configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 为例进行说明。

首先是 spectral normalization 的实现方式,我们提供了两种实现方式,一种是 PyTorch 官方提供的实现,另一种是 BigGAN 作者 ajbrock 提供的实现。我们可以设置 sn_style 为 torch 或者 ajbrock 来选择,如果不设置,默认为 ajbrock。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 使用 BigGAN 作者提供的 SN 实现
model = dict(
    type='BasiccGAN',
    generator=dict(xxx, sn_style='ajbrock'),
    discriminator=dict(xxx, sn_style='ajbrock'),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

在生成模型中,有时希望用一个更强的 D 来引导 G 的更新。常用的一种设置为训练若干步判别器,再训练一步生成器 ,这里 BigGAN 通过设置 train_cfg 中的 disc_steps 和 gen_steps 来实现。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)

这个 config 表示训练 8 步判别器,再训练一步生成器。

在上面代码中,字段 batch_accumulation_steps 涉及到梯度累积操作,因为显存限制,我们很难直接在 batchsize 为 2048 的数据上做 forward 和 backward,因此可以将多个小批量上的梯度平均化,只在 batch_accumulation_steps 次累积后进行优化。假设我们在 8 卡上训练,每张卡 batchsize 为 32,batch_accumulation_steps = 8,这样可以逼近 8*8*32 = 2048 的 batchsize。

为了稳定 GAN 的训练过程,我们往往要使用一种叫指数移动平均的技巧。通过设置 generator 网络的备份 generator_ema 。在每个 train iter 后,将更新的模型参数和历史参数加权平均后作为generator_ema 的参数,这样 generator_ema 的参数更新会比 generator 更加平滑,作为训练结束后的 inference 模型,其生成结果更好。

为了使用 ema,首先需要在上文 train_cfg 中将 use_ema 设置为 True,同时,需要在 config 中添加一个 ExponentialMovingAverageHook。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 custom_hooks = [
    xxx,
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=8,
        start_iter=160000,
        interp_cfg=dict(momentum=0.9999, momentum_nontrainable=0.9999),
        priority='VERY_HIGH')
]

这里 interval 为 generator_ema 参数更新频率,start_inter 为 ema 开始 iteration,在此之前网络参数照搬 generator。interp_cfg 中 momentum 为 parameters 更新的权重,momentum_nontrainable 为 buffers 更新的权重。

对于模型训练的其他细节和具体实现,大家可以参考 MMGeneration 中的代码(当然我们也有可能再出一期详解~)

上面的设置,我们已经在 config 中为大家写好了,只需要运行如下代码,就可以开始训练自己的 BigGAN 模型了!

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
bash tools/dist_train.sh configs/biggan/biggan_ajbrock-sn_imagenet1k_128x128_b32x8_1500k.py 8 --work-dir ./work_dirs/biggan

训练过程中,可以查看 work_dirs/biggan/training_samples 下的不同阶段模型生成图片。前四行为 generator_ema 生成图片,后四行为 generator 同输入下生成的图片。

其实,上面的条件采样,条件插值,模型训练对于 MMGeneration 中已支持的 SNGAN,SAGAN 也是同样适用的,欢迎大家随时使用并提出意见~

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-03-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
前端错误捕获方案总结
本文主要摘抄自:https://juejin.cn/post/7172072612430872584#heading-10,主要用来记录和学习,也推荐大家看看原博主的文章。
蓓蕾心晴
2022/12/30
1.7K0
前端错误捕获方案总结
从0到1搭建前端监控平台,面试必备的亮点项目
常常会苦恼,平常做的项目很普通,没啥亮点;面试中也经常会被问到:做过哪些亮点项目吗?
前端老道
2023/02/27
3.7K0
从0到1搭建前端监控平台,面试必备的亮点项目
JavaScript异常如何处理
在前端的开发工作当中,我们对于异常的处理可能关注的不是太多,因为js有基本的异常处理能力,很多错误会直接抛出来,打开控制台就能看到。但是如果因为异常导致网站卡死,甚至崩溃无法继续进行下去,对于用户的体验是相当差的,我们应该及时的捕获这些异常,对用户进行一些简要的温馨提示,并将异常进行及时的上报,以便于快速解决。
OECOM
2020/07/02
1.7K0
JavaScript异常如何处理
沉淀了3年的自研前端错误监控系统,打通你的脉络
这篇文章是我的好朋友广胤所写,里面记录了我们2018年探索的前端监控体系的历程,由于在建设完后的我离职了,后续也没有继续能和广胤一起更进一步的探索,还是有一些些遗憾。还记得我第一次进入「兑吧」的时候,我就在简历里描述了错误监控之类的项目,其实当时我并没有在一个公司进行过实践,这大概是之前在网易的时候,闲来没事,进行的自我探索。然后进入「兑吧」后,没想到当时公司正好缺少这一块的基建,于是 TL 就让我和广胤负责了这块项目,也是这次经历让我从实习阶段就正式踏入了前端基础建设的道路,还是非常感谢这一次的机会,让我从单一的业务开发人员,转化到了结构型开发人员。记得在开发的项目的那一个月中,除了吃饭,或者和广胤讨论项目的进度问题,近乎一种忘我的开发状态。
秋风的笔记
2021/07/30
1.1K0
前端魔法堂——异常不仅仅是try/catch
前言  编程时我们往往拿到的是业务流程正确的业务说明文档或规范,但实际开发中却布满荆棘和例外情况,而这些例外中包含业务用例的例外,也包含技术上的例外。对于业务用例的例外我们别无它法,必须要求实施人员与用户共同提供合理的解决方案;而技术上的例外,则必须由我们码农们手刃之,而这也是我想记录的内容。  我打算分成《前端魔法堂——异常不仅仅是try/catch》和《前端魔法堂——调用栈,异常实例中的宝藏》两篇分别叙述内置/自定义异常类,捕获运行时异常/语法异常/网络请求异常/PromiseRejection事件,
^_^肥仔John
2018/01/18
1.6K0
vue前端异常监控sentry实践
但是onerror事件无法捕获到网络异常的错误(资源加载失败、图片显示异常等),例如img标签下图片url 404 网络请求异常的时候,onerror无法捕获到异常,此时需要监听unhandledrejection。
csxiaoyao
2019/04/09
1.7K0
前端异常埋点系统初探
https://juejin.cn/post/6965022635470110733
@超人
2021/07/05
1.1K0
前端异常埋点系统初探
如何搭建前端异常监控系统
是指用户在使用应用时,无法得到预期的结果。不同的异常带来的后果程度不同,轻则引起用户使用不悦,重则导致产品无法使用,从而使用户丧失对产品的认可。
iOSSir
2020/07/24
1.9K0
如何搭建前端异常监控系统
web前端监控的三个方面探讨
以 init 为程序的入口,代码中所有同步执行出现的错误都会被捕获,这种方式也可以很好的避免程序刚跑起来就挂。
smy
2018/08/01
1.2K0
web前端监控的三个方面探讨
干货满满!如何做好前端日志和异常监控的思考
在研发过程中,日志是非常重要的一环,它可以帮助我们快速定位问题,解决问题。在前端开发中,日志也是非常重要的一环,它可以帮助我们快速定位问题,解决问题。本文将介绍前端日志的规范和最佳实践。但是我们经常看到一些项目日志打得满天飞,但是到了真正定位问题的时候,发现日志并没有什么卵用。这是因为日志打得不规范,不规范的日志是没有意义的。所以我们需要规范日志的打印,才能让日志发挥最大的作用。
老码小张
2024/03/20
1.6K1
干货满满!如何做好前端日志和异常监控的思考
让前端监控数据采集更高效
随着业务的快速发展,我们对生产环境下的问题感知能力越来越关注。作为距离用户最近的一层,前端的表现是否可靠、稳定、好用,很大程度上决定着用户对整个产品的体验和感受。因此,对于前端的监控不容忽视。
桃翁
2019/05/31
1.5K0
剖析前端异常及其降级处理和防范方案
“异常”一词出自《后汉书.卷一.皇后纪上.光烈阴皇后纪》,表示非正常的,不同于平常的。在我们现实生活中同样处处存在着异常,比如小县城里的路灯年久失修...,上下班高峰期深圳的地铁总是那么的拥挤...,人也总是时不时会生病等等; 由此可见,这个世界错误无处不在,这是一个基本的事实。
winty
2021/07/27
1.4K0
如何搭建前端异常监控系统
是指用户在使用应用时,无法得到预期的结果。不同的异常带来的后果程度不同,轻则引起用户使用不悦,重则导致产品无法使用,从而使用户丧失对产品的认可。
发声的沉默者
2021/06/14
1.3K0
如何搭建前端异常监控系统
一篇文章教你如何捕获前端错误
JavaScript代码在用户浏览器中执行时,由于一些边界情况、本地环境的不可控等因素,可能会存在js运行时错误。
2020labs小助手
2019/07/10
3.5K0
前端异常的捕获与处理
? 这是第 89 篇不掺水的原创,想要了解更多,请戳上方蓝色字体:政采云前端团队 关注我们吧~ 本文首发于政采云前端团队博客:前端异常的捕获与处理 https://www.zoo.team/arti
政采云前端团队
2021/03/16
3.8K0
前端异常的捕获与处理
前端异常捕获和定位
开发阶段,通过详细的报错信息,我们可以快速定位并解决问题。在生产,通过异常监控,根据异常埋点信息,我们可以第一时间知道异常信息,不至于造成严重后果。
GopalFeng
2020/09/24
1.4K0
前端异常捕获和定位
如何及时发现网页的隐形错误
先来说说JavaScript的错误类型,ECMA-262 定义了 7 种错误类型,说明如下:
zayyo
2023/11/02
3240
前端监控那些事
监控这个词对于前端,个人觉得有三个定义,分别是“性能监控”、“异常监控”、“数据监控” 性能监控则是针对web应用的性能,涉及包括用户体验、用户交互时间等 异常监控则是指Web应用得不到预期效果结果的情况监控 数据监控则是获取用户使用过程的行为数据反馈 1.性能监控 性能监控可以让我们更好的监控当前应用的性能情况,然后对性能情况反馈去做优化,性能会影响到用户体验,而常见的性能指标我们能通过浏览器Performance里面看到 1.1 Performace 允许访问当前页面性能相关的信息,perf
树酱
2020/07/03
1.3K0
如何优雅处理前端的异常?
前端一直是距离用户最近的一层,随着产品的日益完善,我们会更加注重用户体验,而前端异常却如鲠在喉,甚是烦人。
桃翁
2019/05/31
1.9K0
React,优雅的捕获异常
https://juejin.cn/post/6974383324148006926
winty
2021/07/27
8780
相关推荐
前端错误捕获方案总结
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验