GAN
这一概念是由Ian Goodfellow
于2014年提出,并迅速成为了非常火热的研究话题,GAN的变种更是有上千种,深度学习先驱之一的Yann LeCun
就曾说,"GAN及其变种是数十年来机器学习领域最有趣的idea
"。那么什么是GAN呢?GAN的应用有哪些呢?GAN的原理是什么呢?怎样去实现一个GAN呢?本文将一一阐述。具体大纲如下:
GAN的英文全称是Generative Adversarial Network
,中文名是生成对抗网络,它由两个部分组成,一个是生成器(generative),还有一个是鉴别器,与生成器是敌对(Adversarial)关系。对GAN有了初步了解,知道它有两个模块组成,下面通过事例来理解这两个模块的产生思想?
在生物进化的过程中,被捕食者会慢慢演化自己的特征,从而达到欺骗捕食者的目的,而捕食者也会根据情况调整自己对被捕食者的识别,共同进化,上图中的啵啵鸟和枯叶蝶就是这样的一种关系。生成器代表的是枯叶蝶,鉴别器代表的是啵啵鸟。它们的对抗思想与GAN类似,但GAN却有所不同。
GAN之所以有所不同,这里的原因是GAN所作的工作与自然界的生物进化不同,它是已经知道最终鉴别的目标是什么样子,不知道假目标是什么样子,它会对生成器所产生的假目标做惩罚和对真目标进行奖励,这样鉴别器就知道什么目标是不好的假目标,什么目标是好的真目标,而生成器则是希望通过进化,产生比上一次更好的假目标,使鉴别器对自己的惩罚更小。以上是一个轮回,下一个轮回,鉴别器通过学习上一个轮回进化的假目标和真目标,再次进化对假目标的惩罚,而生成器不屈不挠,再次进化,直到以假乱真,与真目标一致,至此进化结束。
以上图为例,我们最开始画人物头像只知道有一个头的大致形状,有眼睛有鼻子等等,但画得不精致,后来通过找老师学习,画得更好了,有模有样,直到,我们画得与专门画头像的老师一样好。这里的我们
就像是生成器
,一步步进化(对应生成器不同的等级),这里的老师
就像是鉴别器
(这里只是比喻说明,现实世界的老师已经是一个成熟的鉴别器,不需要通过假样本进行学习,这里有那个意思就行)
玩过纸牌的人知道,赢家的快乐是建立在输家的痛苦之上,收益和损失的总和始终为0。生成器和鉴别器也是这样一对博弈关系:鉴别器惩罚生成器,鉴别器收益,生成器损失;生成器进化,使鉴别器对自己惩罚小,生成器收益,鉴别器损失。
什么是GAN?GAN是由生成器和鉴别器两个部分组成,生成器的目的是生成假的目标,企图彻底骗过鉴别器的识别。而鉴别器通过学习真目标和假目标,提高自己的鉴别能力,不让假目标骗过自己。两者相互进化,相互博弈,一方进化,另一方损失,最后直到假目标与真目标很相似则停止进化。
首先,我们要知道结构化学习
(Structured Learning),GAN也是结构化学习的一种。与分类和回归类似,结构化学习也是需要找到一个X\(\rightarrow\)Y的映射,但结构化学习的输入和输出多种多样,可以是序列(sequence)到序列,序列到矩阵(matrix),矩阵到图(graph),图到树(tree)等等。这样,GAN的应用就十分广泛了。例如,机器翻译(machine translation)可以用GAN去做,如下图所示
还有语音识别(speech recognition)以及聊天机器人(chat-bot)
在图像方面,我们可以做图像转图像(image-to-image),彩色化(colorization),还有文本转图像(text-to-image)
当然,GAN的应用远不止这么些,有非常有趣的变脸,图像自动打马赛克,自动生成多表情图像,年轻转年老等等,更多cool又skr
的应用静待各位挖掘!
GAN的最终目的是为了生成能够产生以假乱真的目标的生成器。那么,是不是一定要用GAN呢?生成器可不可以自己训练得到目标?鉴别器可不可以自己训练得到目标?我们先来看这两个问题,然后再深入讨论GAN。
答案是肯定的,我们所熟知的自编码器
(Auto-Encoder)以及变分自编码器
(Variational Auto-Encoder)都是典型的生成器。输入通过Encoder编码成code,然后code通过Decoder重建原图,其中自编码器中的Decoder就是生成器,code可随机取值,产生不同的输出。
自编码器的结构如下:
变分自编码器的结构如下
然后自编码器存在着问题,我们来看看下面这张图
生成器的问题:由于自编码器的目标是让重建误差越来越小,但从上图中,我们可以看出,其中1个pixel的error,自编码器是觉得ok的,我们是觉得不行,另外6个pixel的误差我们觉得能接受的,自编码器不能接受,误差所在的位置很重要,而生成器并不知道这一点,自编码器缺少理解像素点之间的空间相关性的能力。还有一点,就是自编码器所产生的图像是模糊的,不能够产生十分清晰的图像,如下图所示
所以说目前单凭生成器是很难生成非常高质量的图像的。
答案也是肯定的。鉴别器是给定一个输入,输出一个[0,1]的置信度,越接近1则置信越高,越接近0则置信度越低,如图所示:
鉴别器的优势在于它可以很轻易地捕捉到元素之间的相关性,例如自编码器中出现的像素问题就不会在鉴别器中出现,如图所示,用一个滤波器就解决了。
现在来说说鉴别器要怎么样产生样本,参考下图:
首先也需要随机生成负样本,然后与真实样本一起送入鉴别器进行训练,在循环迭代中,通过最大概率选出最好的负样本,再与真样本一起送入鉴别器进行训练,然而,看起来和GAN训练差不多一致,没啥问题,其实这里面还有存在着问题的。我们来看下面这张图:
鉴别器的问题:鉴别器的训练是对真样本进行奖励,对负样本进行压低,也就是图中的绿色抬高,蓝色压低,这就造成了问题,我们要训练出好的鉴别器,训练过程需要随机采样出除绿色图像外所有的假样本,这样鉴别器就只会对真实样本的分布取高分,对其他分布取低分,这样才能训练的好,然后再高维空间中,这样的负样本采样过程其实是很难进行的,而且还有一个问题,生成样本的过程要枚举大量样本,才有可能出现一个与真样本分布相符的样本,通过求那个最大化概率问题求出最好的样本,这实在是过于繁琐。
通过上面的阐述,我们初步知道了它们的优缺点,下面这张ppt直观地给出了每个的优缺点,如图所示:
可以看出生成器和鉴别器的优缺点是可以互补的,这也就是GAN的优势。(生成器+鉴别器),下图介绍了GAN的优点,从两个角度出发。
当然,GAN也是又缺点的,它是一种隐变量模型,可解释没有生成器和鉴别器强,另外GAN是不好进行训练。我在训练DAGAN的时候就成功造成了鉴别器的误差为0,无法进行反向传播更新梯度。
对于生成器而言,它的目标是希望能够学习到真实样本的分布,这样就可以随机生成以假乱真的样本。如下图所示
如何去学习真实样本分布呢,这就需要用到极大似然估计
(Maximum Likelihood Estimation),先来看看下面这张图
我们需要随机采样真实分布中的数据,通过学习\(P(x;\theta)\)中的\(\theta\),希望\(P(x;\theta)\)越接近\(P_{data}(x)\),其中每一个\(x\)对应的\(P_{data}(x)\)的概率是很大的,为了使\(P(x;\theta)\)越接近\(P_{data}(x)\),原问题等价于最大化每一个\(P(x_i;\theta)\),合起来就是最大化\(\prod_{i=1}^mP_{G}(x^i;\theta)\)。而实际上极大似然估计是等价于最小化\(KL-divergence\),具体推导看下图,先取\(log\)(\(log\)是单调递增,不会改变原问题)将相乘化为相加,最后变成了\(P_{data}\)下\(logP_{G}(x;\theta)\)的期望,然后转化成积分的形式,后面加了一项\(\intop_xP_{data}(x)logP_{data}(x)dx\),这一项是一个常数,没有变量\(\theta\),加了也不会影响原问题的解,加了这一项之后原问题就等于最小化\(P_{data}和P_{G}\)的\(KL-divergence\)。
我们已经知道生成器要做的是\(arg\space \underset{G}{min}\space Div(P_{data},P_{G})\),这里\(P_{G}\)是我们要去最优化的,虽然我们有真实样本,但\(P_G\)的分布我们还是不知道,而且如何去定量计算\(P_{data}\)和\(P_G\)的\(divergence\),也就是\(Div(P_{data},P_G)\),我们也是不知道的。所以接下来就需要引入鉴别器了。 虽然我们不知道\(P_G\)和\(P_{data}\)的分布,但我们可以随机采样它们分布的样本,如下图所示:
而我们知道鉴别器的目标是给真样本奖励,假样本惩罚,如下图所示,最后得到要鉴别器要优化的目标函数,鉴别器希望能够最大化这个目标函数,也就是\(arg \space \underset{D}{max}\space V(D,G)\).注意,这里是是将\(G\)是\(fixed\),是不变的。
我们再来解这个问题,解出最优\(D^*\),接下来的步骤就比较数学了,给一个目标函数,求出极大值解。具体如图下
这个求解过程还是蛮详细的,最后我们竟然得到最大化\(V(D,G)\)竟然等于一个常数加上\(P_G\)和\(P_{data}\)的\(JS-divergence\)(\(JS-divergence\)与\(KL-divergence\)类似,不会改变解),这正是我们在生成器一直想求,可不会求得东西,鉴别器帮我们做到了。 于是,原始生成器的最优化问题\(arg\space\underset{G}{min}Div(P_G,P_{data})\)就可以转化成\(arg\space\underset{G}{min}\space \underset{D}{max}V(G,D)\)。那如何来求解\(arg\space\underset{G}{min}\space \underset{D}{max}V(G,D)\)这个最小最大问题呢?其实上面图上已经给出答案了,通过固定其中一个,求另一个,然后固定另一个,求之前固定住的这个。具体做法如图下:
更加详细的实践过程(也就是GAN的训练过程)如下所示,相信看了上面的一系列解释,会对GAN如此训练有了比较深的理解了吧。
GAN的理论就到此结束。
这里使用数据集是Anime——台大李宏毅老师的GAN课程的数据集,点击链接下载,首先我们来看一下DCGAN的框架,如图所示
这个是生成器的结构图,鉴别器的结构与生成器大致相反,DCGAN与普通的GAN有一些区别,具体分为下面几点
import torch
import torch.nn as nn
import torch.functional as F
class Generate(nn.Module):
def __init__(self, input_dim=100):
super(Generate, self).__init__()
channel = [512, 256, 128, 64, 3]
kernel_size = 4
stride = 2
padding = 1
self.convtrans1_block = self.__convtrans_bolck(input_dim, channel[0], 6, padding=0, stride=stride)
self.convtrans2_block = self.__convtrans_bolck(channel[0], channel[1], kernel_size, padding, stride)
self.convtrans3_block = self.__convtrans_bolck(channel[1], channel[2], kernel_size, padding, stride)
self.convtrans4_block = self.__convtrans_bolck(channel[2], channel[3], kernel_size, padding, stride)
self.convtrans5_block = self.__convtrans_bolck(channel[3], channel[4], kernel_size, padding, stride, layer="last_layer")
def __convtrans_bolck(self, in_channel, out_channel, kernel_size, padding, stride, layer=None):
if layer == "last_layer":
convtrans = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)
tanh = nn.Tanh()
return nn.Sequential(convtrans, tanh)
else:
convtrans = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)
batch_norm = nn.BatchNorm2d(out_channel)
relu = nn.ReLU(True)
return nn.Sequential(convtrans, batch_norm, relu)
def forward(self, inp):
x = self.convtrans1_block(inp)
x = self.convtrans2_block(x)
x = self.convtrans3_block(x)
x = self.convtrans4_block(x)
x = self.convtrans5_block(x)
return x
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
channels = [3, 64, 128, 256, 512]
kernel_size = 4
stride = 2
padding = 1
self.conv_bolck1 = self.__conv_block(channels[0], channels[1], kernel_size, stride, padding, "first_layer")
self.conv_bolok2 = self.__conv_block(channels[1], channels[2], kernel_size, stride, padding)
self.conv_bolok3 = self.__conv_block(channels[2], channels[3], kernel_size, stride, padding)
self.conv_bolok4 = self.__conv_block(channels[3], channels[4], kernel_size, stride, padding)
self.conv_bolok5 = self.__conv_block(channels[4], 1, kernel_size+1, stride, 0, "last_layer")
def __conv_block(self, inchannel, outchannel, kernel_size, stride, padding, layer=None):
if layer == "first_layer":
conv = nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding, bias=False)
leakrelu = nn.LeakyReLU(0.2, inplace=True)
return nn.Sequential(conv, leakrelu)
elif layer == "last_layer":
conv = nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding, bias=False)
sigmoid = nn.Sigmoid()
return nn.Sequential(conv, sigmoid)
else:
conv = nn.Conv2d(inchannel, outchannel, kernel_size, stride, padding, bias=False)
batchnorm = nn.BatchNorm2d(outchannel)
leakrelu = nn.LeakyReLU(0.2, inplace=True)
return nn.Sequential(conv, batchnorm, leakrelu)
def forward(self,inp):
x = self.conv_bolck1(inp)
x = self.conv_bolok2(x)
x = self.conv_bolok3(x)
x = self.conv_bolok4(x)
x = self.conv_bolok5(x)
return x
def weight_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0,0.01)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0,0.01)
m.bias.data.fill_(0)
if __name__ == "__main__":
model1 = Generate()
x = torch.randn(10,100,1,1)
y = model1.forward(x)
print(y.size())
model2 = Discriminator()
a = torch.randn(10,3,96,96)
b = model2.forward(a)
print(b.size())
import torch,torch.utils.data
import numpy as np
import scipy.misc, os
class AnimeDataset(torch.utils.data.Dataset):
def __init__(self, directory, dataset, size_per_dataset):
self.directory = directory
self.dataset = dataset
self.size_per_dataset = size_per_dataset
self.data_files = []
data_path = os.path.join(directory, dataset)
for i in range(size_per_dataset):
self.data_files.append(os.path.join(data_path,"{}.jpg".format(i)))
def __getitem__(self, ind):
path = self.data_files[ind]
img = scipy.misc.imread(path)
img = img.transpose(2,0,1)-127.5/127.5
return img
def __len__(self):
return len(self.data_files)
if __name__ == "__main__":
dataset = AnimeDataset(os.getcwd(),"anime",100)
loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True,num_workers=4)
for i, inp in enumerate(loader):
print(i,inp.size())
import os, imageio,scipy.misc
import matplotlib.pyplot as plt
def creat_gif(gif_name, img_path, duration=0.3):
frames = []
img_names = os.listdir(img_path)
img_list = [os.path.join(img_path, img_name) for img_name in img_names]
for img_name in img_list:
frames.append(imageio.imread(img_name))
imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
def visualize_loss(generate_txt_path, discriminator_txt_path):
with open(generate_txt_path, 'r') as f:
G_list_str = f.readlines()
with open(discriminator_txt_path, 'r') as f:
D_list_str = f.readlines()
D_list_float, G_list_float = [], []
for D_item, G_item in zip(D_list_str, G_list_str):
D_list_float.append(float(D_item.strip().split(':')[-1]))
G_list_float.append(float(G_item.strip().split(':')[-1]))
list_epoch = list(range(len(D_list_float)))
full_path = os.path.join(os.getcwd(), "saved/logging.png")
plt.figure()
plt.plot(list_epoch, G_list_float, label="generate", color='g')
plt.plot(list_epoch, D_list_float, label="discriminator", color='b')
plt.legend()
plt.title("DCGAN_Anime")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig(full_path)
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision.utils import make_grid
from model import Generate,Discriminator,weight_init
from AnimeDataset import AnimeDataset
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import os, argparse
from tqdm import tqdm
from utils import creat_gif, visualize_loss
def main():
parse = argparse.ArgumentParser()
parse.add_argument("--lr", type=float, default=0.0001,
help="learning rate of generate and discriminator")
parse.add_argument("--beta1", type=float, default=0.5,
help="adam optimizer parameter")
parse.add_argument("--batch_size", type=int, default=64,
help="number of dataset in every train or test iteration")
parse.add_argument("--dataset", type=str, default="faces",
help="base path for dataset")
parse.add_argument("--epochs", type=int, default=500,
help="number of training epochs")
parse.add_argument("--loaders", type=int, default=4,
help="number of parallel data loading processing")
parse.add_argument("--size_per_dataset", type=int, default=30000,
help="number of training data")
parse.add_argument("--pre_train", type=bool, default=False,
help="whether load pre_train model")
args = parse.parse_args()
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
if not os.path.exists("saved"):
os.mkdir("saved")
if not os.path.exists("saved/img"):
os.mkdir("saved/img")
if os.path.exists("faces"):
pass
else:
print("Don't find the dataset directory, please copy the link in website ,download and extract faces.tar.gz .\n \
https://1drv.ms/u/s!AgBYzHhocQD4g0_Fr-mC-DYfWahJ \n ")
exit()
if args.pre_train:
generate = torch.load("saved/generate.t7").to(device)
discriminator = torch.load("saved/discriminator.t7").to(device)
else:
generate = Generate().to(device)
discriminator = Discriminator().to(device)
generate.apply(weight_init)
discriminator.apply(weight_init)
dataset = AnimeDataset(os.getcwd(), args.dataset, args.size_per_dataset)
dataload = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
criterion = nn.BCELoss().to(device)
optimizer_G = Adam(generate.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
optimizer_D = Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
fixed_noise = torch.randn(64, 100, 1, 1).to(device)
for epoch in range(args.epochs):
print("Main epoch{}:".format(epoch))
progress = tqdm(total=len(dataload.dataset))
loss_d, loss_g = 0, 0
for i, inp in enumerate(dataload):
# train discriminator
real_data = inp.float().to(device)
real_label = torch.ones(inp.size()[0]).to(device)
noise = torch.randn(inp.size()[0], 100, 1, 1).to(device)
fake_data = generate(noise)
fake_label = torch.zeros(fake_data.size()[0]).to(device)
optimizer_D.zero_grad()
real_output = discriminator(real_data)
real_loss = criterion(real_output.squeeze(), real_label)
real_loss.backward()
fake_output = discriminator(fake_data)
fake_loss = criterion(fake_output.squeeze(), fake_label)
fake_loss.backward()
loss_D = real_loss + fake_loss
optimizer_D.step()
#train generate
optimizer_G.zero_grad()
fake_data = generate(noise)
fake_label = torch.ones(fake_data.size()[0]).to(device)
fake_output = discriminator(fake_data)
loss_G = criterion(fake_output.squeeze(), fake_label)
loss_G.backward()
optimizer_G.step()
progress.update(dataload.batch_size)
progress.set_description("D:{}, G:{}".format(loss_D.item(), loss_G.item()))
loss_g += loss_G.item()
loss_d += loss_D.item()
loss_g /= (i+1)
loss_d /= (i+1)
with open("generate_loss.txt", 'a+') as f:
f.write("loss_G:{} \n".format(loss_G.item()))
with open("discriminator_loss.txt", 'a+') as f:
f.write("loss_D:{} \n".format(loss_D.item()))
if epoch % 20 == 0:
torch.save(generate, os.path.join(os.getcwd(), "saved/generate.t7"))
torch.save(discriminator, os.path.join(os.getcwd(), "saved/discriminator.t7"))
img = generate(fixed_noise).to("cpu").detach().numpy()
display_grid = np.zeros((8*96,8*96,3))
for j in range(int(64/8)):
for k in range(int(64/8)):
display_grid[j*96:(j+1)*96,k*96:(k+1)*96,:] = (img[k+8*j].transpose(1, 2, 0)+1)/2
img_save_path = os.path.join(os.getcwd(),"saved/img/{}.png".format(epoch))
scipy.misc.imsave(img_save_path, display_grid)
creat_gif("evolution.gif", os.path.join(os.getcwd(),"saved/img"))
visualize_loss("generate_loss.txt", "discriminator_loss.txt")
if __name__ == "__main__":
main()
代码运行请参考github的[readme][https://github.com/FangYang970206/Anime_GAN],最后500个epoch的结果图如下
WGAN pytorch版本一直都有bug,目前还没找到原因,实现了一个keras版本的,代码如下(运行前记得看readme):
import os,scipy.misc
import keras.backend as K
from keras.models import Sequential, Model
from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input
from keras.layers import Conv2DTranspose, Reshape, Activation, Cropping2D, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import RMSprop
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02)
os.environ['KERAS_BACKEND']='tensorflow'
os.environ['TENSORFLOW_FLAGS']='floatX=float32,device=cuda'
def DCGAN_D(isize, nc, ndf):
inputs = Input(shape=(isize, isize, nc))
x = ZeroPadding2D()(inputs)
x = Conv2D(ndf, kernel_size=4, strides=2, use_bias=False, kernel_initializer=conv_init)(x)
x = LeakyReLU(alpha=0.2)(x)
for _ in range(4):
x = ZeroPadding2D()(x)
x = Conv2D(ndf*2, kernel_size=4, strides=2, use_bias=False, kernel_initializer=conv_init)(x)
x = BatchNormalization(epsilon=1.01e-5, gamma_init=gamma_init)(x, training=1)
x = LeakyReLU(alpha=0.2)(x)
ndf *= 2
x = Conv2D(1, kernel_size=3, strides=1, use_bias=False, kernel_initializer=conv_init)(x)
outputs = Flatten()(x)
return Model(inputs=inputs, outputs=outputs)
def DCGAN_G(isize, nz, ngf):
inputs = Input(shape=(nz,))
x = Reshape((1, 1, nz))(inputs)
x = Conv2DTranspose(filters=ngf, kernel_size=3, strides=2, use_bias=False,
kernel_initializer = conv_init)(x)
for _ in range(4):
x = Conv2DTranspose(filters=int(ngf/2), kernel_size=4, strides=2, use_bias=False,
kernel_initializer = conv_init)(x)
x = Cropping2D(cropping=1)(x)
x = BatchNormalization(epsilon=1.01e-5, gamma_init=gamma_init)(x, training=1)
x = Activation("relu")(x)
ngf = int(ngf/2)
x = Conv2DTranspose(filters=3, kernel_size=4, strides=2, use_bias=False,
kernel_initializer = conv_init)(x)
x = Cropping2D(cropping=1)(x)
outputs = Activation("tanh")(x)
return Model(inputs=inputs, outputs=outputs)
nc = 3
nz = 100
ngf = 1024
ndf = 64
imageSize = 96
batchSize = 64
lrD = 0.00005
lrG = 0.00005
clamp_lower, clamp_upper = -0.01, 0.01
netD = DCGAN_D(imageSize, nc, ndf)
netD.summary()
netG = DCGAN_G(imageSize, nz, ngf)
netG.summary()
clamp_updates = [K.update(v, K.clip(v, clamp_lower, clamp_upper))
for v in netD.trainable_weights]
netD_clamp = K.function([],[], clamp_updates)
netD_real_input = Input(shape=(imageSize, imageSize, nc))
noisev = Input(shape=(nz,))
loss_real = K.mean(netD(netD_real_input))
loss_fake = K.mean(netD(netG(noisev)))
loss = loss_fake - loss_real
training_updates = RMSprop(lr=lrD).get_updates(netD.trainable_weights,[], loss)
netD_train = K.function([netD_real_input, noisev],
[loss_real, loss_fake],
training_updates)
loss = -loss_fake
training_updates = RMSprop(lr=lrG).get_updates(netG.trainable_weights,[], loss)
netG_train = K.function([noisev], [loss], training_updates)
fixed_noise = np.random.normal(size=(batchSize, nz)).astype('float32')
datagen = ImageDataGenerator(
# featurewise_center=True,
# featurewise_std_normalization=True,
rotation_range=20,
rescale=1./255
)
train_generate = datagen.flow_from_directory("faces/", target_size=(96,96), batch_size=64,
shuffle=True, class_mode=None, save_format='jpg')
step = 0
print(dir(train_generate))
for step in range(100000):
for _ in range(5):
real_data = (np.array(train_generate.next())*2-1)
noise = np.random.normal(size=(batchSize, nz))
errD_real, errD_fake = netD_train([real_data, noise])
errD = errD_real - errD_fake
netD_clamp([])
noise = np.random.normal(size=(batchSize, nz))
errG, = netG_train([noise])
print('[%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f' % (step, errD, errG, errD_real, errD_fake))
if step%1000==0:
netD.save("discriminator.h5")
netG.save("generate.h5")
fake = netG.predict(fixed_noise)
display_grid = np.zeros((8*96,8*96,3))
for j in range(int(64/8)):
for k in range(int(64/8)):
display_grid[j*96:(j+1)*96,k*96:(k+1)*96,:] = fake[k+8*j]
img_save_path = os.path.join(os.getcwd(),"saved/img/{}.png".format(step))
scipy.misc.imsave(img_save_path, display_grid)
代码运行请参考github的[readme][https://github.com/FangYang970206/Anime_GAN],100000step的结果:
1.对真实图片进行归一化,与生成图片分布一样,也就是[-1,1]. 2.随机噪声使用高斯分布,不要使用均匀分布,也就是在代码中使用torch.randn,而不是torch.rand 3.初始化权重很有必要,详细见model.py中的weight_init函数 4.在训练时,在鉴别器中产生的noise,生成器也要用这个noise进行参数,这点很重要。我最开始的时候就是鉴别器随机产生noise,生成器也随机产生noise,训练得很不好。 5.在训练过程中,很有可能鉴别器的loss等于0(鉴别器太强了,起初我试过减小鉴别器的学习率,但还是会有这个情况,我猜想原因是在某一个batch中,鉴别器恰好将随机噪声产生的图片和真实图片完全区分开,loss为0),导致生成器崩溃(梯度弥散),所以最好按多少个epoch保存模型,然后在导入模型再训练。个人觉得数据增强和增大batchsize会减弱这种情况的可能性,这个还未实践。
1 李宏毅GAN课程及PPT 2 DCGAN paper 3 chenyuntc