经典论文复现LSGAN:最小二乘生成对抗网络

可验证的知识是科学的基础,它事关理解。随着人工智能领域的发展,打破不可复现性将是必要的。为此,PaperWeekly联手百度PaddlePaddle共同发起了本次论文有奖复现,我们希望和来自学界、工业界的研究者一起接力,为AI行业带来良性循环。

作者丨文永亮

学校丨华南理工大学

研究方向丨目标检测、图像生成

笔者这次选择复现的是LeastSquaresGenerativeAdversarialNetworks,也就是LSGANs。

LSGANs这篇经典的论文主要工作是把交叉熵损失函数换做了最小二乘损失函数,这样做作者认为改善了传统GAN的两个问题,即传统GAN生成的图片质量不高,而且训练过程十分不稳定。

LSGANs试图使用不同的距离度量来构建一个更加稳定而且收敛更快的,生成质量高的对抗网络。但是我看过WGAN的论文之后分析这一损失函数,其实并不符合WGAN作者的分析。在下面我会详细分析一下为什么LSGANs其实并没有那么好用。

论文复现代码:

LSGANs的优点

我们知道传统GAN生成的图片质量不高,传统的GANs使用的是交叉熵损失(sigmoidcrossentropy)作为判别器的损失函数。

在这里说一下我对交叉熵的理解,有两个分布,分别是真实分布p和非真实分布q。

信息熵是

,就是按照真实分布p这样的样本空间表达能力强度的相反值,信息熵越大,不确定性越大,表达能力越弱,我们记作H(p)。交叉熵就是

,可以理解为按照不真实分布q这样的样本空间表达能力强度的相反值,记作H(p,q)。

KL散度就是D(p||q)=H(p,q)-H(p),它表示的是两个分布的差异,因为真实分布p的信息熵固定,所以一般由交叉熵来决定,所以这就是为什么传统GAN会采用交叉熵的缘故,论文也证明了GAN损失函数与KL散度的关系。

我们知道交叉熵一般都是拿来做逻辑分类的,而像最小二乘这种一般会用在线性回归中,这里为什么会用最小二乘作为损失函数的评判呢?

使用交叉熵虽然会让我们分类正确,但是这样会导致那些在决策边界被分类为真的、但是仍然远离真实数据的假样本(即生成器生成的样本)不会继续迭代,因为它已经成功欺骗了判别器,更新生成器的时候就会发生梯度弥散的问题。

论文指出最小二乘损失函数会对处于判别成真的那些远离决策边界的样本进行惩罚,把远离决策边界的假样本拖进决策边界,从而提高生成图片的质量。作者用下图详细表达了这一说法:

我们知道传统GAN的训练过程十分不稳定,这很大程度上是因为它的目标函数,尤其是在最小化目标函数时可能发生梯度弥散,使其很难再去更新生成器。而论文指出LSGANs可以解决这个问题,因为LSGANs会惩罚那些远离决策边界的样本,这些样本的梯度是梯度下降的决定方向。

论文指出因为传统GAN辨别器D使用的是sigmoid函数,并且由于sigmoid函数饱和得十分迅速,所以即使是十分小的数据点x,该函数也会迅速忽略样本x到决策边界w的距离。这就意味着sigmoid函数本质上不会惩罚远离决策边界的样本,并且也说明我们满足于将x标注正确,因此辨别器D的梯度就会很快地下降到0。

LSGANs的损失函数

传统GAN的损失函数:

LSGANs的损失函数:

其中G为生成器(Generator),D为判别器(Discriminator),z为噪音,它可以服从归一化或者高斯分布,为真实数据x服从的概率分布,为z服从的概率分布。为期望值,同为期望值。

defgenerator(z,name="G"):

withfluid.unique_name.guard(name+'_'):

fc1=fluid.layers.fc(input=z,size=1024)

fc1=fluid.layers.fc(fc1,size=128*7*7)

fc1=fluid.layers.batch_norm(fc1,act='tanh')

fc1=fluid.layers.reshape(fc1,shape=(-1,128,7,7))

conv1=fluid.layers.conv2d(fc1,num_filters=4*64,

filter_size=5,stride=1,

padding=2,act='tanh')

conv1=fluid.layers.reshape(conv1,shape=(-1,64,14,14))

conv2=fluid.layers.conv2d(conv1,num_filters=4*32,

conv2=fluid.layers.reshape(conv2,shape=(-1,32,28,28))

conv3=fluid.layers.conv2d(conv2,num_filters=1,

#conv3=fluid.layers.reshape(conv3,shape=(-1,1,28,28))

print("conv3",conv3)

returnconv3

▲生成器代码展示

defdiscriminator(image,name="D"):

conv1=fluid.layers.conv2d(input=image,num_filters=32,

filter_size=6,stride=2,

padding=2)

conv1_act=fluid.layers.leaky_relu(conv1)

conv2=fluid.layers.conv2d(conv1_act,num_filters=64,

conv2=fluid.layers.batch_norm(conv2)

conv2_act=fluid.layers.leaky_relu(conv2)

fc1=fluid.layers.reshape(conv2_act,shape=(-1,64*7*7))

fc1=fluid.layers.fc(fc1,size=512)

fc1_bn=fluid.layers.batch_norm(fc1)

fc1_act=fluid.layers.leaky_relu(fc1_bn)

fc2=fluid.layers.fc(fc1_act,size=1)

print("fc2",fc2)

returnfc2

▲判别器代码展示

作者提出了两种abc的取值方法:

1.使b-c=1,b-a=2,例如a=-1,b=1,c=0:

2.使c=b,用0-1二元标签,我们可以得到:

作者在文献中有详细推倒过程,详细说明了LSGAN与f散度之间的关系,这里简述一下。

通过对下式求一阶导可得到D的最优解:

代入:

其中另加项并不影响的值,因为它不包含参数G。

最后我们设b-c=1,b-a=2就可以得到:

其中就是皮尔森卡方散度。

LSGANs未能解决的地方

下面我会指出LSGANs给出的损失函数到底符不符合WGAN前作的理论。关于WGAN前作及WGAN论文的分析可以参考本文[5]。

上面我们指出了D的最优解为公式(5),我们最常用的设a=-1,b=1,c=0可以得出:

把最优判别器带入上面加附加项的生成器损失函数可以表示为:

也就是优化上面说的皮尔森卡方散度,其实皮尔森卡方散度和KL散度、JS散度有一样的问题,根据WGAN给出的理论,下面用P1,P2分别表示和。

当P1与P2的支撑集(support)是高维空间中的低维流形(manifold)时,P1与P2重叠部分测度(measure)为0的概率为1。也就是P1和P2不重叠或重叠部分可忽略的可能性非常大。

对于数据点x,只可能发生如下四种情况:

1.P1(x)=0,P2(x)=0

2.P1(x)!=0,P2(x)!=0

3.P1(x)=0,P2(x)!=0

4.P1(x)!=0,P2(x)=0

可以想象成下面这幅图,假设P1(x)分布就是AB线段,P2(x)分布就是CD线段,数据点要么在两条线段的其中一条,要么都不在,同时在两条线段上的可能性忽略不计。

情况1是没有意义的,而情况2由于重叠部分可忽略的可能性非常大所以对计算损失贡献为0,情况3可以算出D*=-1,损失是个定值1,情况4类似。

所以我们可以得出结论,当P1和P2不重叠或重叠部分可忽略的可能性非常大时,当判别器达到最优时,生成器仍然是不迭代的,因为此时损失是定值,提供的梯度仍然为0。同时我们也可以从另一个角度出发,WGAN的Wasserstein距离可以变换如下:

它要求函数f要符合Lipschitz连续,可是最小二乘损失函数是不符合的,他的导数是没有上界的。所以结论就是LSGANs其实还是未能解决判别器足够优秀的时候,生成器还是会发生梯度弥散的问题。

两种模型架构和训练

模型的结构

作者也提出了两类架构:

第一种处理类别少的情况,例如MNIST、LSUN。网络设计如下:

第二类处理类别特别多的情形,实际上是个条件版本的LSGAN。针对手写汉字数据集,有3740类,提出的网络结构如下:

训练数据

论文中使用了很多场景的数据集,然后比较了传统GANs和LSGANs的稳定性,最后还通过训练3740个类别的手写汉字数据集来评价LSGANs。

▲本文使用的数据集列表

在LSUN和HWDB1.0的这两个数据集上使用LSGANs的效果图如下,其中LSUN使用了里面的bedroom,kitchen,church,diningroom和conferenceroom五个场景,bedroom场景还对比了DCGANs和EBGANs的效果在图5中,可以观察到LSGANs生成的效果要比那两种的效果好。

图7则体现了LSGANs和传统GANs生成的图片对比。

通过实验观察,作者发现4点技巧:

1.生成器G带有batchnormalization批处理标准化(以下简称BN)并且使用Adam优化器的话,LSGANs生成的图片质量好,但是传统GANs从来没有成功学习到,会出现modecollapse现象;

2.生成器G和判别器D都带有BN层,并且使用RMSProp优化器处理,LSGANs会生成质量比GANs高的图片,并且GANs会出现轻微的modecollapse现象;

3.生成器G带有BN层并且使用RMSProp优化器,生成器G判别器D都带有BN层并且使用Adam优化器时,LSGANs与传统GANs有着相似的表现;

4.RMSProp的表现比Adam要稳定,因为传统GANs在G带有BN层时,使用RMSProp优化可以成功学习,但是使用Adam优化却不行。

下面是使用LSGANs和GANs学习混合高斯分布的数据集,下图展现了生成数据分布的动态结果,可以看到传统GAN在Step15k时就会发生modecollapse现象,但LSGANs非常成功地学习到了混合高斯分布。

论文具体实现

笔者使用了MNIST数据集进行实验,具体实现效果如下:

LSGANs:

GAN:

从本次用MNIST数据训练的效果来看,LSGANs生成的效果似乎是比GAN的要清晰高质量一些。

总结

LSGANs是对GAN的一次优化,从实验的情况中,笔者也发现了一些奇怪的现象。我本来是参考论文把判别器D的损失值,按真假两种loss加起来一并放入Adam中优化,但是无论如何都学习不成功,梯度还是弥散了,最后把D_fake_loss和D_real_loss分为两个program,放入不同的Adam中优化判别器D的参数才达到预期效果。

这篇论文中的思想是非常值得借鉴的,从最小二乘的距离的角度考量,并不是判别器分类之后就完事了,但是LSGANs其实还是未能解决判别器足够优秀的时候,生成器梯度弥散的问题。

关于PaddlePaddle

笔者反馈:帮助文档有点少,而且我本来就直接写好了想改成使用GPU运算,没找到怎么改;

PaddlePaddle团队:关于如何使用GPU运行,可以看下执行器Executor(单GPU或单线程CPU执行器)或ParallelExecutor(多GPU或多线程CPU执行器,也可以单GPU/线程CPU执行)的文档,前者指定place为CUDAPlace,后者接口有个use_cuda,具体请参考文档。也可以看modelsrepo例子,比如image_classification或text_classification的例子。

笔者反馈:Program这个概念有点新颖,一个模型可以有多个Program,但是我实现的GAN可以只用一个,也可以分别放进三个Program,没有太了解到Program这个概念的优越之处,我还是像计算图那样使用了,官方也没给出与TensorFlow的对比。

PaddlePaddle团队:关于Program设计可以参考官方文档。这里提一点,在用户使用的直观感受中和TensorFlowgraph不同的是,凡是放在一个Program里op,只要运行该Program,这些op就都会执行;而TensorFlow,指定一个variable,只运行以该variable为叶子节点的graph,其他多余node不执行,这是最大的用户感受到的区别。

小道消息:听说全新版本的PaddlePaddle已于今日发布哦。

参考文献

[1].I.Goodfellow,J.Pouget-Abadie,M.Mirza,B.Xu,D.Warde-Farley,S.Ozair,A.Courville,andY.Bengio,“Generativeadversarialnets,”inAdvancesinNeuralInformationProcessingSystems(NIPS),pp.2672–2680,2014.

[2].M.Arjovsky,S.Chintala,andL.Bottou.WassersteinGAN.arXivpreprintarXiv:1701.07875,2017.

[3].AndrewBrock,JeffDonahueandKarenSimonyan.LargeScaleGANTrainingforHighFidelityNaturalImageSynthesis.arXiv:1809.11096,2018.

[4].Schlegl,Thomas,etal."UnsupervisedAnomalyDetectionwithGenerativeAdversarialNetworkstoGuideMarkerDiscovery."arXivpreprintarXiv:1703.05921(2017).

点击标题查看更多论文解读:

#

稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)

THE END
1.(casiahwdb)汉字识别数据集The online and offline Chinese handwriting databases, CASIA-OLHWDB and CASIA-HWDB, were built by the National Laboratory of Pattern Recognition (NLPR), Institute of Automation of Chinese Academy of Sciences (CASIA). The handwritten samples were produced by 1,020 writers using Anoto pen on papershttp://www.nlpr.ia.ac.cn/databases/handwriting/Home.html
2.keras+卷积神经网络HWDB手写汉字识别keras+卷积神经网络HWDB手写汉字识别 写在前面 HWDB手写汉字数据集来自于中科院自动化研究所,下载地址: http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.ziphttps://blog.csdn.net/yql_617540298/article/details/82251994
3.celeba数据集CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. The images in this dataset cover large pose variations and background clutter. CelebA has large diversities, large quantities, and rich annotationshttp://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
4.手写汉字数据集(部分)手写汉字数据集(HWDB1.1),图片形式的各种汉字以经分别在各个文件夹内存储好。 手写汉字 数据集2018-09-05 上传大小:42.00MB 所需:43积分/C币 CNN卷积神经网络识别手写汉字MNIST数据集.zip 这是我修改的别人的代码,别人的代码有点问题,我修改了一下,代码的正确率很高,可达90%以上,这是一个5层卷积神经网络的代https://www.iteye.com/resource/qq_27280237-10648261
5.Gbase8a数据库安装与使用HWDB-1.1 手写汉字CNN识别模型训练 数据集 使用CASIA-HWDB1.1进行训练和测试,训练集和测试集按照4:1划分,测试集235200张,训练集940800张, 共计1,176,000张图像。该数据集由300个人手写而成,其中包含171个阿拉伯数字和特殊符号,3755类GB2312-80 level-1汉字。 http://www.nlpr.ia.ac.cn/databases/handwriting/https://www.pianshen.com/article/7084303285/
6.基于机器学习的方法实现手写数据集识别系统手写字体识别数据集下载HWDB1.1数据集: 1. $ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip 2. # zip解压没得说, 之后还要解压alz压缩文件 3. $ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip https://blog.51cto.com/u_16213702/8807334
7.使用python获取CASIA脱机和在线手写汉字库CASIA-HWDB CASIA-OLHWDB 在申请书中介绍了数据集的基本情况: >CASIA-HWDB和CASIA-OLHWDB数据库由中科院自动化研究所在 2007-2010 年间收集, 均各自包含 1,020 人书写的脱机(联机)手写中文单字样本和手写文本, 用 Anoto 笔在点阵纸上书写后扫描、分割得到。 https://www.imooc.com/article/40759
8.CASIAHWDB脱机手写汉字数据集以及申请表下载我真的找遍全网,总算是找到了这个数据集,现在分享给大家。共六个文件,分别是CASIA-HWDB1.0训练集和测试集、CASIA-HWDB1.1训练集和测试集、CASIA-Competition数据集还有一张申请表。不过我看大多数人都是把前四个文件合并起来当做训练集,用Competition那个做测试集的。【注:2019年春节期间数据集的官网打不开,现在https://www.jianshu.com/p/980e2528e8fe