# Knowledge Distillation with the Reused Teacher Classifier (CVPR 2022)

Paper: https://arxiv.org/abs/2203.14001
This repository provides a compact and easy-to-use implementation of several representative knowledge distillation approaches on standard image classification benchmarks (e.g., CIFAR-100 and ImageNet).
Generally, these KD approaches combine a classification loss, a logit-level distillation loss, and an additional feature distillation loss. For fair comparison and ease of tuning, we fix the weights of the first two loss terms to one throughout all experiments (`--cls 1 --div 1`).
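Concretely, the overall objective is a weighted sum of the three terms, with `-b` scaling the method-specific feature loss. The sketch below is illustrative only; the helper name and the temperature value are assumptions, not the repo's exact code:

```python
import torch.nn.functional as F

# Illustrative combination of the three loss terms; the function name and the
# temperature T are assumptions, not this repository's exact implementation.
def kd_objective(logits_s, logits_t, labels, loss_feat,
                 cls_w=1.0, div_w=1.0, beta=1.0, T=4.0):
    loss_cls = F.cross_entropy(logits_s, labels)              # classification loss (--cls)
    loss_div = F.kl_div(F.log_softmax(logits_s / T, dim=1),   # logit-level KD loss (--div)
                        F.softmax(logits_t / T, dim=1),
                        reduction='batchmean') * (T * T)
    return cls_w * loss_cls + div_w * loss_div + beta * loss_feat  # feature loss (-b)
```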
This toolbox currently supports a range of approaches covering vanilla KD, feature-map/feature-embedding distillation, and instance-level/pairwise-level distillation (see the results table below for the full list).
This toolbox is built on an open-source benchmark and our previous repository; implementations of additional KD approaches can be found there.
Computing Infrastructure:
- CIFAR-100 experiments: one NVIDIA GeForce RTX 2080 Ti GPU, PyTorch 1.0.
- ImageNet experiments: four NVIDIA A40 GPUs, PyTorch 1.10. We use NVIDIA DALI for data loading and pre-processing (a minimal pipeline sketch follows this list).
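For reference, a minimal DALI training pipeline might look like the following. The augmentation settings here are assumptions for illustration, not the exact configuration used by this repository:

```python
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

# Hypothetical ImageNet training pipeline; augmentation details are assumptions.
@pipeline_def
def imagenet_train_pipe(data_dir, crop=224):
    jpegs, labels = fn.readers.file(file_root=data_dir, random_shuffle=True, name="Reader")
    images = fn.decoders.image(jpegs, device="mixed")   # JPEG decoding on GPU
    images = fn.random_resized_crop(images, size=crop)
    images = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        mirror=fn.random.coin_flip(),                   # random horizontal flip
    )
    return images, labels

pipe = imagenet_train_pipe(data_dir="../data/imagenet/train",
                           batch_size=256, num_threads=8, device_id=0)
pipe.build()
loader = DALIClassificationIterator(pipe, reader_name="Reader")
```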
The current code has been reorganized and not yet thoroughly re-tested. If you have any questions, please do not hesitate to contact us.
Please place the CIFAR-100 and ImageNet datasets in `../data/`.
```bash
# CIFAR-100
python train_teacher.py --batch_size 64 --epochs 240 --dataset cifar100 --model resnet32x4 --learning_rate 0.05 --lr_decay_epochs 150,180,210 --weight_decay 5e-4 --trial 0 --gpu_id 0

# ImageNet
python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0
```
The pretrained teacher models used in our paper are available at this link [GoogleDrive].
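A downloaded teacher checkpoint can then be loaded roughly as follows. This is a sketch assuming the RepDistiller-style layout this toolbox builds on; the `model_dict` import and the `'model'` checkpoint key are assumptions, so inspect the checkpoint if they differ:

```python
import torch
from models import model_dict  # assumed RepDistiller-style model registry

# Assumed checkpoint layout: a dict holding the weights under the 'model' key.
model_t = model_dict['resnet32x4'](num_classes=100)
ckpt = torch.load('./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth',
                  map_location='cpu')
model_t.load_state_dict(ckpt['model'])
model_t.eval()  # teacher stays frozen during distillation
```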
```bash
# CIFAR-100
python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0

# ImageNet
python train_student.py --path-t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23444 --multiprocessing-distributed --dali gpu --trial 0
```
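Note that SimKD is trained with `-c 0 -d 0 -b 1`, i.e., only the feature-alignment loss is active: the student's features are mapped into the teacher's feature space by a small projector and then fed through the reused, frozen teacher classifier. Below is a minimal sketch of this idea operating on pooled feature vectors for brevity (the class name is hypothetical, and the paper's projector works on feature maps before pooling):

```python
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch of the SimKD idea, not the repo's exact module.
class SimKDHead(nn.Module):
    def __init__(self, s_dim, t_dim, teacher_classifier):
        super().__init__()
        self.proj = nn.Sequential(            # student -> teacher feature space
            nn.Linear(s_dim, t_dim),
            nn.ReLU(inplace=True),
            nn.Linear(t_dim, t_dim),
        )
        self.classifier = teacher_classifier  # reused teacher classifier, frozen
        for p in self.classifier.parameters():
            p.requires_grad = False

    def forward(self, feat_s, feat_t):
        feat_s = self.proj(feat_s)
        loss_feat = F.mse_loss(feat_s, feat_t)  # the only training signal (-b 1)
        logits = self.classifier(feat_s)        # predictions via the reused classifier
        return logits, loss_feat
```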
More scripts are provided in `./scripts`.
| Method | ResNet-8x4 | VGG-8 | ShuffleNetV2x1.5 |
| --- | --- | --- | --- |
Student | 73.09 | 70.46 | 74.15 |
KD | 74.42 | 72.73 | 76.82 |
FitNet | 74.32 | 72.91 | 77.12 |
AT | 75.07 | 71.90 | 77.51 |
SP | 74.29 | 73.12 | 77.18 |
VID | 74.55 | 73.19 | 77.11 |
CRD | 75.59 | 73.54 | 77.66 |
SRRL | 75.39 | 73.23 | 77.55 |
SemCKD | 76.23 | 75.27 | 79.13 |
SimKD (f=8) | 76.73 | 74.74 | 78.96 |
SimKD (f=4) | 77.88 | 75.62 | 79.48 |
SimKD (f=2) | 78.08 | 75.76 | 79.54 |
Teacher (ResNet-32x4) | 79.42 | 79.42 | 79.42 |
Top-1 test accuracy (%) on CIFAR-100 with different student architectures (Teacher: ResNet-32x4).
If you find this repository useful, please consider citing the following papers:

```bibtex
@inproceedings{chen2022simkd,
  title={Knowledge Distillation with the Reused Teacher Classifier},
  author={Chen, Defang and Mei, Jian-Ping and Zhang, Hailin and Wang, Can and Feng, Yan and Chen, Chun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11933--11942},
  year={2022}
}

@inproceedings{chen2021cross,
  title={Cross-Layer Distillation with Semantic Calibration},
  author={Chen, Defang and Mei, Jian-Ping and Zhang, Yuan and Wang, Can and Wang, Zhe and Feng, Yan and Chen, Chun},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  pages={7028--7036},
  year={2021}
}
```