
Data-free knowledge distillation for target class classification

Objective: Researchers currently rely mostly on data-free distillation to cope with the lack of training data. However, in practical applications, existing data-free distillation methods suffer from difficult model convergence and insufficiently compact student models. To meet the need to train models for only a subset of classes and to flexibly select the teacher network's knowledge of the target classes, this paper proposes a new data-free knowledge distillation method: masked distillation for target classes (MDTC). Method: Building on a generator that learns the batch-normalization parameter distribution of the original data, MDTC uses a mask to block the backpropagation of gradients for non-target classes during the generator's gradient updates, training a generator that synthesizes samples of the target classes only and thereby accurately extracting the specific knowledge in the teacher model. In addition, MDTC introduces the teacher model into the feature learning of the generator's intermediate layers and optimizes the generator's initial parameter settings and parameter update strategy to accelerate model convergence. Result: Thirteen sub-classification tasks are designed on four standard image classification datasets to evaluate MDTC on tasks of varying difficulty. The experimental results show that MDTC extracts the specific knowledge in the teacher model accurately and efficiently; its overall accuracy surpasses that of mainstream data-free distillation models while requiring less training time. More than 40% of the student models even exceed the accuracy of the teacher model, with a maximum gain of 3.6%. Conclusion: The overall performance of the proposed method exceeds that of existing data-free distillation models; its knowledge-learning efficiency is especially high on easy classification tasks, and the model performs best when the extracted knowledge covers only a small proportion of the classes.
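As a rough illustration of the masking idea summarized above, the sketch below restricts a standard KL-divergence distillation loss to a chosen subset of target classes, so that no gradient flows through the non-target outputs. It is a minimal PyTorch-style example under assumed names (`masked_kd_loss`, `target_classes`) and an assumed temperature value, not the authors' implementation.

```python
# Minimal sketch of target-class masking in a distillation loss (assumed
# helper, not the paper's code): only the logits of the chosen target
# classes take part in the KL term, so gradients for all other classes
# are blocked during backpropagation.
import torch
import torch.nn.functional as F

def masked_kd_loss(teacher_logits, student_logits, target_classes, temperature=4.0):
    """KL distillation loss computed only over the target classes.

    teacher_logits, student_logits: tensors of shape (batch, num_classes).
    target_classes: list of class indices the student should learn.
    """
    idx = torch.as_tensor(target_classes, dtype=torch.long,
                          device=teacher_logits.device)
    # Index selection acts as the mask: non-target columns are dropped,
    # so they contribute neither to the loss nor to the gradients.
    t = teacher_logits.index_select(1, idx) / temperature
    s = student_logits.index_select(1, idx) / temperature
    return F.kl_div(F.log_softmax(s, dim=1),
                    F.softmax(t, dim=1),
                    reduction="batchmean") * temperature ** 2
```

In principle, the same masked teacher output can also serve as the feedback signal for the generator update described above, so that non-target classes never drive the synthesis of training samples.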
Data-free knowledge distillation for target class classification
Objective Knowledge distillation is a simple and effective method for compressing neural networks and has become a popular topic in model compression research. This method features a "teacher-student" architecture in which a large network guides the training of a small network to improve its performance in application scenarios, indirectly achieving network compression. In traditional methods, the training of the student model relies on the training data of the teacher, and the quality of the student model depends on the quality of the training data. When faced with data scarcity, these methods fail to produce satisfactory results. Data-free knowledge distillation addresses the issue of limited training data by introducing synthetic data. Such methods mainly synthesize training data by refining teacher network knowledge; for example, they use the intermediate representations of the teacher network for image inversion synthesis, or employ the teacher network as a fixed discriminator that supervises a generator of synthetic images for training the student network. Compared with traditional methods, the training of data-free knowledge distillation does not rely on the original training data of the teacher network, which markedly expands the application scope of knowledge distillation. However, the training process may be less efficient than that of traditional methods because additional synthetic training data must be generated. Furthermore, practical applications often focus on only a few target classes, yet existing data-free knowledge distillation methods have difficulty selectively learning the knowledge of the target classes; in particular, when the teacher model covers a large number of classes, convergence becomes difficult and the student model cannot be made sufficiently compact. Therefore, this paper proposes a novel data-free knowledge distillation method, namely masked distillation for target classes (MDTC). This method allows the student model to selectively learn the knowledge of target classes and maintains good performance even when the teacher network contains numerous classes. Compared with traditional methods, MDTC reduces the training difficulty and improves the training efficiency of data-free knowledge distillation.

Method MDTC utilizes a generator to learn the batch-normalization parameter distribution of the raw data and trains a generator that produces only target-class samples by creating a mask that blocks the gradient backpropagation of non-target classes during the generator's gradient updates. The method thus extracts target knowledge from the teacher model while generating synthetic data similar to the original data. In addition, MDTC introduces the teacher model into the feature learning of the generator's intermediate layers, supervises the training of the generator, and optimizes the generator's initial parameter settings and parameter update strategy to accelerate model convergence. The MDTC algorithm is divided into two stages. The first is the data synthesis stage, which fixes the student network and updates only the generator. During the generator update, MDTC extracts three synthetic samples from the shallow, middle, and deep layers of the generator, inputs them into the teacher network for prediction, and updates the parameters of the generator according to the feedback of the teacher network. When updating the shallow- and middle-layer parameters, the other layers of the generator are kept fixed and only that layer is updated; finally, when updating the output layer, the parameters of the entire generator are updated, gradually guiding the generator to learn to synthesize the target images. The second stage is the learning stage, in which the generator is fixed and the synthetic samples are input into the teacher and student networks for prediction. The target knowledge of the teacher is extracted by the mask, and the Kullback-Leibler (KL) divergence with respect to the predicted output of the student network is used to update the student network.

Result Four standard image classification datasets, namely, MNIST, SVHN, CIFAR10, and CIFAR100, are divided into 13 subclassification tasks by Pearson similarity calculation, including eight difficult tasks and five easy tasks. The performance of MDTC on subclassification tasks of different difficulties is evaluated by classification accuracy. The method is also compared with five mainstream data-free knowledge distillation methods and the vanilla KD method. Experimental results show that the proposed method outperforms the other mainstream data-free distillation models on 11 subtasks. Moreover, on MNIST1, MNIST2, SVHN1, SVHN3, CIFAR102, and CIFAR104 (6 of the 13 subclassification tasks), the proposed method even surpasses the teacher model trained on the original data, achieving accuracy rates of 99.61%, 99.46%, 95.85%, 95.80%, 94.57%, and 95.00%, respectively, a remarkable 3.6% improvement over the 91.40% accuracy of the teacher network on CIFAR104.

Conclusion In this study, a novel data-free knowledge distillation method, MDTC, is proposed. The experimental results indicate that MDTC outperforms existing data-free distillation models overall; it is especially efficient at learning knowledge for easy classification tasks and when the target classes make up a low proportion of all classes, displaying excellent performance when extracting knowledge from a limited set of categories.
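The Method section states that the generator learns the batch-normalization parameter distribution of the raw data. One common way to impose such a constraint, sketched below under assumed names (`BNStatHook`, `bn_statistics_loss`) and with an assumed squared-error penalty rather than the paper's exact formulation, is to compare the batch statistics of the synthetic images, as seen by the teacher's BatchNorm layers, with the running statistics those layers store.

```python
# Hedged sketch: match the synthetic batch's statistics to the running
# mean/variance stored in the teacher's BatchNorm2d layers. The hook class
# and the squared-error penalty are illustrative assumptions.
import torch.nn as nn

class BNStatHook:
    """Records the distance between the current batch statistics and the
    running statistics of one BatchNorm2d layer during a forward pass."""
    def __init__(self, bn):
        self.loss = 0.0
        self.handle = bn.register_forward_hook(self._on_forward)

    def _on_forward(self, bn, inputs, _output):
        x = inputs[0]                                  # (N, C, H, W) features
        mean = x.mean(dim=(0, 2, 3))
        var = x.var(dim=(0, 2, 3), unbiased=False)
        self.loss = ((mean - bn.running_mean) ** 2).sum() + \
                    ((var - bn.running_var) ** 2).sum()

def bn_statistics_loss(teacher, synthetic):
    """Sum of per-layer statistic mismatches for one batch of synthetic images."""
    hooks = [BNStatHook(m) for m in teacher.modules()
             if isinstance(m, nn.BatchNorm2d)]
    teacher(synthetic)                                 # forward pass fires the hooks
    loss = sum(h.loss for h in hooks)
    for h in hooks:
        h.handle.remove()
    return loss
```

Minimizing a term of this kind with respect to the generator's parameters, while the teacher stays frozen, pushes the synthetic images toward the feature statistics of the original training data.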

deep learning; image classification; model compression; data-free knowledge distillation; generators

Xie Yitao, Su Lumei, Yang Fan, Chen Yuhan


School of Electrical Engineering and Automation, Xiamen University of Technology, Xiamen 361024, China

Department of Automation, Xiamen University, Xiamen 361102, China


2024

中国图象图形学报 (Journal of Image and Graphics)
Sponsors: Institute of Remote Sensing Applications, Chinese Academy of Sciences; China Society of Image and Graphics; Institute of Applied Physics and Computational Mathematics


Indexed in: CSTPCD; Peking University Core Journals (北大核心)
Impact factor: 1.111
ISSN:1006-8961
Year, volume (issue): 2024, 29(11)