模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

彩虹网

https://github.com/huawei-noah/DAFL

1.1 DAFL 原理分析:

华为诺亚方舟实验室联合北京大学和悉尼大学提出了在无数据情况下的网络蒸馏方法 DAFL,比之前的最好算法在MNIST上提升了6个百分点,并且使用 resnet18 在 CIFAR-10 和 100 上分别达到了 92% 和 74% 的准确率 (无需训练数据)。

它的特点是:

主要步骤是:

通过待压缩网络训练生成器

通过生成器 输出生成图片作为训练样本

通过训练样本蒸馏待压缩网络 得到压缩后的网络

1.11 知识蒸馏方法获得学生网络

蒸馏算法最早由Hinton提出,待压缩网络 (教师网络) 为一个具有高准确率但参数很多的神经网络,初始化一个参数较少的学生网络,通过让学生网络的输出和教师网络相同,学生网络的准确率在教师的指导下得到提高。

从结构和参数的角度看,如上文所述,待压缩的大网络的结构和参数都是未知的,这就使得我们无法通过剪枝或者量化等经典的神经网络压缩方法进行模型压缩,我们唯一已知的就是待压缩的大网络的输入和输出接口。

从训练数据的角度看,DAFL 的训练样本是由生成器 生成的,是没有标签的,所以没法通过有监督的方式学习学生网络,基于这两点,作者引入了教师学生网络学习范式,利用蒸馏算法实现利用未标注生成样本对黑盒网络的压缩。

令 和 分别代表教师和学生网络,则作者使用交叉熵损失来使得学生网络的输出符合教师网络的输出,具体的损失函数为:

式中, 指交叉熵损失函数, 和 分别是教师和学生网络的输出。通过引入教师学生算法,作者解决了生成图片没有标签的问题,并且可以在待压缩网络结构未知的情况下对其进行压缩。

1.12 通过 GAN 生成无标注的训练图片

从训练数据的角度看,在整个网络压缩的过程中,我们都没有任何给定的训练数据,在此情况下,神经网络的压缩变得十分困难。所以作者通过 GAN 来输出一些无标注的训练图片,以便于神经网络的压缩。生成对抗网络 (GAN) 是一种可以生成数据的方法,包含生成网络 与判别网络 ,生成网络希望输出和真实数据类似的图片来骗过判别器,判别网络通过判别生成图片和真实图片的真伪来帮助生成网络训练。

具体而言,给定一个任意的噪声向量 (noise vector) ,生成器 会把它映射成虚假的图片 即 。另一方面, 判别器 要区分来的一张图片是真实的 还是生成器伪造的 , 所以,对于 GAN 而言,它的目标函数可以写成:

这个目标函数 的优化方法是 。就是每轮优化分为2步,第1步是通过gradient ascent 优化 的参数,第2步是通过 gradient descent 优化 的参数。然而,我们会 发现传统的 GAN 需要基于真实数据 来训练判别器,这对于我们来说是无法进行的。所以基于传 统的 GAN 训练方法 2 式是不行的。

许多研究表明,训练好的判别器 具有提取图像特征的能力,提取到的特征可以直接用于分类任务,所以,由于待压缩网络使用真实图片进行训练,也同样具有提取特征的能力,从而具有一定的分辨图像真假的能力。而且这个待压缩网络我们是已有的。于是,我们把待压缩网络作为一个固定的判别器 ,以此来训练我们的生成网络 。

首先,待压缩网络作为一个固定的判别器 ,我们就认为它是已经训练好参数的判别器 ,我们利用它来训练生成器的基本思想是下式:

式中, 就是已经训练好参数的判别器,生成器 的参数经过3式持续优化使得 逐渐上升,代表着生成器的输出越来越能够骗过判别器。

但是,在传统GAN中,传统的判别器 的输出是判定图片是否真假 (Real or Fake?),只要让生成网络生成在判别器中分类为真的图片即可训练,但是,我们的待压缩网络为分类网络,其输出是分类结果 (1-num_classes),所以待压缩网络无法直接作为一个固定的判别器 。因此需要重新设计生成网络的目标。通过观察真实图片在分类网络的响应,作者提出了以下损失函数。

1) 伪标签交叉熵损失

在图像分类任务中,神经网络的训练采用的是交叉熵损失函数,在训练完成后,真实图片在网络中的输出将会是一个one-hot的向量,即分类类别对应的输出为1,其他的输出为0。于是,我们希望生成图片也具有类似的性质。给定一组任意的噪声向量 ,它们通过生成器 之后得到的生成图片是 ,这里 。

现在把这些生成图片 输入给待压缩的网络,通过 得到输出 ,预测标签就是通过 计算得到 。定义伪标签交叉熵损失为:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

其中 就是标准的交叉熵函数,由于生成图片并没有一个真实的标签,我们直接将其输出最大值对应的标签设定为它的伪标签。 伪标签交叉熵损失的意思就是对于一张生成的图片,它的标签就按照教师网络的输出来决定,这是训练生成器 的第1个损失。

2) 特征激活损失函数

在神经网络的训练中,由卷积核提取的特征也是输入图片的一种重要表示。先前的许多工作表明,卷积核提取的特征包含着图片的许多重要信息,将训练数据输入训练好的深度网络中,卷积核会产生更大的响应 (相比于噪声或与此网络无关的数据),基于此,作者提出了特征激活损失函数。定义生成图片 经过教师网络得到的特征是 ,则特征激活损失函数定义为:

反向传播优化生成器参数的方法是:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

因为待压缩网络 (即教师网络) 是训练好的,所以目标是让生成图像在待压缩网络中的特征响应值更大,来使图片更接近训练数据。这里作者采用了1范数来优化,原因是1范数相比于2范数会产生更加稀疏的值,而神经网络的响应也常常是稀疏的。

3) 信息熵损失函数

为了让神经网络更好的训练,真实的训练数据对于每个类别的样本数目通常都保持一致,例如MNIST每个类别都含有 6000 张图片。于是,为了让生成网络产生各个类别样本的概率基本相同,作者引入信息熵,信息熵是针对一个概率分布而言的。假设现在有概率分布 ,概率分布 的信息熵的计算方法就是:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

概率分布 越均匀,信息熵 就越小。极限情况当 时,信息熵 取极大值 。所以信息熵损失函数定义为:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

其中 为标准的信息熵,信息熵的值越大,对于生成的一组样本经过待压缩教师网络的输出特征 来讲,每个类别的数目就越平均,从而保证了生成样本的类别平均。

反向传播优化生成器参数的方法是:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

最后,我们将这三个损失函数 (4,5,9式) 组合起来,就可以得到我们生成器总的损失函数:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

通过优化以上的损失函数,训练得到的生成器可以和真实的样本在待压缩网络具有类似的响应,从而更接近真实样本,且生成的数据的分布十分均匀。

DAFL 的流程和算法如下图1和图2所示。把待压缩网络当做判别器 ,通过上式 12 作为损失函数来训练生成器 。通过生成器 来得到足够的生成图片,这些图片的分布与训练教师网络的训练数据是一致的。然后,再通过上式 1 的蒸馏损失和这些生成图片对教师网络进行蒸馏得到学生网络。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图1:DAFL 框架

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图2:DAFL 算法

1.13 实验结果

作者在MNIST、CIFAR、CelebA三个数据集上分别进行了实验。

MNIST 实验

MNIST 数据集: 10类,60000 training+10000 testing。

作者实验了卷积模型和全连接模型,卷积模型使用 LeNet-5。全连接模型使用 Hinton 提出的具有3个全连接层的网络 Hinton-784-1200-1200-10 作为待压缩模型,将他们的通道数目减半分别作为学生模型 (LeNet-5-HALF 和 Hinton-784-800-800-10)。

图3的前三行是在原始数据集的实验结果。我们以 LeNet-5 模型为例。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图3:MNIST 数据集实验结果

CIFAR 实验

CIFAR-10 数据集: 10类,50000 training+10000 testing。

CIFAR-100 数据集: 100类,50000 training+10000 testing。

作者还在 CIFAR-10 和 CIFAR-100 数据集上进行了实验,使用的教师和学生模型分别为 Resnet-34 和 Resnet-18。

图3的前三行是在原始数据集的实验结果。我们以 CIFAR-10 数据集的结果为例。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图4:CIFAR 数据集实验结果

CelebA 实验

CelebA 数据集:202599 training images

作者又在 CelebA 数据集上进行了实验,使用的教师和学生模型分别为 AlexNet 和 AlexNet-Half。GAN 模型取 DCGAN。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图5:CelebA 数据集实验结果

对比实验

由于我们的方法由很多损失函数组成,我们通过消融实验来分析每个损失函数项的必要性。对比试验的数据集是 MNIST,教师网络是 LeNet-5,学生网络是 LeNet-5-HALF。

下图6是消融实验的结果,一个三个损失函数:伪标签交叉熵损失,特征激活损失函数,信息熵损失函数。可以看到,如果一个都不用,就相当于是直接使用噪声蒸馏学生网络,则准确率是88.01%。使用不同的损失函数,精度如图,每一项损失都很重要。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图6:消融实验的结果

可视化结果

作者对教师和学生得到的卷积核做了可视化,如下图7所示。可以发现,我们的方法学到的学生网络和教师网络具有非常相似的结构,证明了本论文方法的有效性。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图7:卷积核可视化结果

作者还对训练得到的生成器 产生的图片进行了可视化,如下图8所示。注意生成的图像是没有 label 的,它们的类别是由教师网络的预测定义的。图8显示了每个类图像的平均值。虽然没有提供真实的图像,但生成的图像与训练图像具有相似的模式,这说明生成器可以以某种方式学习数据的分布。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图8:生成器输出的图片结果可视化

小结

DAFL 是一个新的无需训练数据的网络压缩方法,它的特点是: 待压缩网络看作一个固定的判别器 ,用生成器 输出的生成图片代替训练数据集进行训练,设计了伪标签交叉熵损失,特征激活损失函数,信息熵损失函数来训练生成器 ,使用生成数据结合蒸馏算法得到压缩后的网络。

2 适合大数据集的无需训练数据的网络压缩技术 (Arxiv 2021)

论文名称:Large-Scale Generative Data-Free Distillation

论文地址:

Large-Scale Generative Data-Free Distillation

https:///abs/2012.05578

谷歌提出的这个适合大数据集的无需训练数据的网络压缩技术基于上节介绍的 DAFL,解决的主要问题是 DAFL 没法在大数据集上使用的问题。

知识蒸馏方法是解决无标注模型压缩问题的一种重要的手段。但是正如前文所述,它的假设是在蒸馏阶段训练数据集是可得到的。但是实际情况是:在现实中的应用上,由于隐私因素的制约或者传输条件的限制,我们无法获得训练数据。比如:在医学图像场景中,用户不想让自己的照片 (数据) 被泄露;训练数据太多没办法传到云端,甚至是存储这些巨大量的数据集对于小型企业都是个难题;所以,使用常规的模型压缩办法在这些限制下无法被使用。

上节介绍的 DAFL 是一个新的无需训练数据的网络压缩方法,它的特点是: 待压缩网络看作一个固定的判别器 ,用生成器 输出的生成图片代替训练数据集进行训练,设计了伪标签交叉熵损失,特征激活损失函数,信息熵损失函数来训练生成器 ,使用生成数据结合蒸馏算法得到压缩后的网络。

本文就是基于 DAFL 实现的,解决了 DAFL 无法在大数据集 ImageNet 上使用的问题。

在这里需要再次强调的一点是:当前有许多用 GAN 来生成自然图像/高清图片/漫画图片/风格迁移/去雨/去噪/去模糊/去马赛克,等等等等各种生成任务。但是,它们无一例外在 GAN 模型的训练过程中都使用了大量的训练数据,而这在实际的业务条件下有时候是不被允许的。

本文的方法框架如下图9所示。本质上和 DAFL 的两个阶段是一致的,都是先用生成器 输出的生成图片代替训练数据集进行训练,然后使用生成数据结合蒸馏算法得到压缩后的网络。下面介绍作者是怎么做,能够解决了 DAFL 无法在大数据集 ImageNet 上使用的问题的。

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

图9:本文提出的无数据蒸馏方法框架

2.11 知识蒸馏方法获得学生网络

蒸馏算法最早由Hinton提出,待压缩网络 (教师网络) 为一个具有高准确率但参数很多的神经网络,初始化一个参数较少的学生网络,通过让学生网络的输出和教师网络相同,学生网络的准确率在教师的指导下得到提高。

从结构和参数的角度看,如上文所述,待压缩的大网络的结构和参数都是未知的,这就使得我们无法通过剪枝或者量化等经典的神经网络压缩方法进行模型压缩,我们唯一已知的就是待压缩的大网络的输入和输出接口。

令  和  分别代表教师和学生网络,则作者使用交叉熵损失来使得学生网络的输出符合教师网络的输出,具体的损失函数为:

模型压缩经典解读:解决训练数据问题,无需数据的神经网络压缩技术

式中, 指KL 散度损失函数,描述的是教师网络和学生网络的输出的差异, 指训练数据的分布,这里的训练数据和 DAFL 一样后续通过 GAN 来生成。通过引入教师学生算法,作者解决了生成图片没有标签的问题,并且可以在待压缩网络结构未知的情况下对其进行压缩。

2.12 通过 GAN 生成无标注的训练图片

从训练数据的角度看,在整个网络压缩的过程中,我们都没有任何给定的训练数据,在此情况下,神经网络的压缩变得十分困难。所以作者通过 GAN 来输出一些无标注的训练图片,以便于神经网络的压缩。生成对抗网络 (GAN) 是一种可以生成数据的方法,包含生成网络 与判别网络 ,生成网络希望输出和真实数据类似的图片来骗过判别器,判别网络通过判别生成图片和真实图片的真伪来帮助生成网络训练。

这个基本的流程和 DAFL 是一致的,但是本文的目标函数设计与 DAFL 有差别。

1) Inceptionism loss

免责声明:由于无法甄别是否为投稿用户创作以及文章的准确性,本站尊重并保护知识产权,根据《信息网络传播权保护条例》,如我们转载的作品侵犯了您的权利,请您通知我们,请将本侵权页面网址发送邮件到qingge@88.com,深感抱歉,我们会做删除处理。