Knowledge-Distillation-Zoo

PyTorch implementation of various Knowledge Distillation (KD) methods.

This repository is a simple reference that mainly focuses on basic knowledge distillation/transfer methods. Thus, many tricks and variations, such as step-by-step training, iterative training, ensembles of teachers, ensembles of KD methods, data-free distillation, self-distillation and quantization, are not considered. I hope it is useful for your project or research.

I will update this repo regularly with new KD methods. If I have missed some basic methods, please contact me.

Lists

Name | Method | Paper Link | Code Link
Baseline | basic model with softmax loss | - | code
Logits | mimic learning via regressing logits | paper | code
ST | soft target | paper | code
AT | attention transfer | paper | code
Fitnet | hints for thin deep nets | paper | code
NST | neural selective transfer | paper | code
PKT | probabilistic knowledge transfer | paper | code
FSP | flow of solution procedure | paper | code
FT | factor transfer | paper | code
RKD | relational knowledge distillation | paper | code
AB | activation boundary | paper | code
SP | similarity preservation | paper | code
Sobolev | Sobolev/Jacobian matching | paper | code
BSS | boundary supporting samples | paper | code
CC | correlation congruence | paper | code
LwM | learning without memorizing | paper | code
IRG | instance relationship graph | paper | code
VID | variational information distillation | paper | code
OFD | overhaul of feature distillation | paper | code
AFD | attention feature distillation | paper | code
CRD | contrastive representation distillation | paper | code
DML | deep mutual learning | paper | code
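The table only names the methods. For the two logit-based entries, a minimal sketch might look like the following: Logits regresses the teacher's logits with an MSE loss, and ST matches temperature-softened class distributions with a KL divergence scaled by T^2. The temperature `T=4.0`, the function names and the `lambda_kd` weighting in the usage part are illustrative assumptions, not values or code taken from this repository's train_xxx.py.

```python
import torch
import torch.nn.functional as F


def logits_loss(out_s: torch.Tensor, out_t: torch.Tensor) -> torch.Tensor:
    """Logits (mimic learning): regress the teacher's logits with mean squared error."""
    return F.mse_loss(out_s, out_t)


def st_loss(out_s: torch.Tensor, out_t: torch.Tensor, T: float = 4.0) -> torch.Tensor:
    """ST (soft target): KL(teacher || student) on temperature-softened softmax outputs.

    Multiplying by T**2 keeps the gradient scale roughly independent of T.
    """
    log_p_s = F.log_softmax(out_s / T, dim=1)
    p_t = F.softmax(out_t / T, dim=1)
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * T * T


if __name__ == "__main__":
    torch.manual_seed(0)
    out_t = torch.randn(8, 10)                      # teacher logits (8 samples, 10 classes)
    out_s = torch.randn(8, 10, requires_grad=True)  # student logits
    labels = torch.randint(0, 10, (8,))
    lambda_kd = 1.0  # hypothetical trade-off weight, in the spirit of --lambda_kd
    loss = F.cross_entropy(out_s, labels) + lambda_kd * st_loss(out_s, out_t.detach())
    loss.backward()
    print(loss.item())
```

In a real run the teacher logits come from the frozen baseline model, so they are detached before entering either loss.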
  • Note that there are some differences between this repository and the original papers:
    • For AT: I use the channel-wise sum of absolute values raised to the power p=2 as the attention map.
    • For Fitnet: The training procedure has a single stage, without a separate hint-layer stage.
    • For NST: I employ a polynomial kernel with d=2 and c=0.
    • For AB: Two-stage training: the first 50 epochs are used for initialization, and the second stage only employs CE without ST.
    • For BSS: 80% of the epochs employ the CE+BSS loss, and the remaining 20% use only CE. In addition, warmup is applied for the first 10 epochs.
    • For CC: For consistency, I only consider CC without instance congruence. The Gaussian RBF kernel is employed because the bilinear pool kernel is similar to PKT. I use the P=2 order Taylor expansion of the Gaussian RBF kernel. No special sampling strategy is used.
    • For LwM: I employ it after rb2 (the middle conv layer) rather than rb3 (the last conv layer), because the base net is a ResNet that ends with GAP followed by a classifier. If applied after rb3, the Grad-CAM has the same values across H and W in each channel.
    • For IRG: I only use one-to-one mode.
    • For VID: I set the hidden channel size to be the same as the output channel size and remove BN in μ.
    • For AFD: I found the original attention implementation unstable, so I replaced it with an SE block.
    • For DML: Only two nets are employed. They are updated synchronously to avoid multiple forward passes.
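The AT and NST notes above can be made concrete with a short sketch. Following the AT note, the attention map is the channel-wise sum of |A|^p with p=2, flattened and L2-normalized per sample; following the NST note, the squared MMD between the student and teacher channel distributions uses the polynomial kernel k(x, y) = (x^T y + c)^d with d=2 and c=0. Normalizing each channel over its H×W positions follows the usual NST formulation and is an assumption about how this repository does it; all function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_map(fm: torch.Tensor, p: int = 2) -> torch.Tensor:
    """AT attention: sum over channels of |A|^p, flattened and L2-normalized per sample."""
    am = fm.abs().pow(p).sum(dim=1)           # (N, H, W)
    return F.normalize(am.flatten(1), dim=1)  # (N, H*W)


def at_loss(fm_s: torch.Tensor, fm_t: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Squared distance between normalized attention maps.

    Channel counts may differ; the spatial size H x W must match.
    """
    return (attention_map(fm_s, p) - attention_map(fm_t, p)).pow(2).mean()


def poly_kernel(a: torch.Tensor, b: torch.Tensor, d: int = 2, c: float = 0.0) -> torch.Tensor:
    """Polynomial kernel (x^T y + c)^d between all channel pairs: (N, Ca, Cb)."""
    return (torch.bmm(a, b.transpose(1, 2)) + c).pow(d)


def nst_loss(fm_s: torch.Tensor, fm_t: torch.Tensor, d: int = 2, c: float = 0.0) -> torch.Tensor:
    """Squared MMD between student and teacher channel distributions, averaged over the batch."""
    # Each channel's flattened spatial activation vector is one sample, L2-normalized over H*W.
    s = F.normalize(fm_s.flatten(2), dim=2)
    t = F.normalize(fm_t.flatten(2), dim=2)
    return (poly_kernel(s, s, d, c).mean() + poly_kernel(t, t, d, c).mean()
            - 2 * poly_kernel(s, t, d, c).mean())
```

Both losses tolerate different channel counts between teacher and student, but they still assume matching H×W.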
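For the CC note, the Taylor expansion works because, for L2-normalized embeddings x and y, exp(-gamma * ||x - y||^2) = exp(-2 * gamma) * exp(2 * gamma * x^T y), so a P=2 expansion only needs powers of the cosine similarity. The sketch below builds the batch correlation matrix that way and compares student and teacher with a mean squared error. The value gamma=0.4 and the function names are hypothetical, and the embedding projections used in the CC paper are omitted.

```python
import math

import torch
import torch.nn.functional as F


def cc_correlation(feat: torch.Tensor, gamma: float = 0.4, order: int = 2) -> torch.Tensor:
    """(N, N) correlation matrix using a Taylor-approximated Gaussian RBF kernel."""
    emb = F.normalize(feat.flatten(1), dim=1)  # one unit-length embedding per sample
    sim = emb @ emb.t()                        # cosine similarities x^T y
    corr = torch.zeros_like(sim)
    for p in range(order + 1):
        # Term p of exp(-2*gamma) * sum_p (2*gamma*x^T y)^p / p!
        coeff = math.exp(-2 * gamma) * (2 * gamma) ** p / math.factorial(p)
        corr = corr + coeff * sim.pow(p)
    return corr


def cc_loss(feat_s: torch.Tensor, feat_t: torch.Tensor,
            gamma: float = 0.4, order: int = 2) -> torch.Tensor:
    """Correlation congruence: MSE between student and teacher correlation matrices."""
    return F.mse_loss(cc_correlation(feat_s, gamma, order),
                      cc_correlation(feat_t, gamma, order))
```

Because the loss compares N×N matrices, the student and teacher features may have different dimensions; only the batch size has to match.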
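The synchronous DML update in the last note could be sketched as follows: each net runs one forward pass, each loss adds a KL term toward the peer's detached prediction, and both optimizers step on the same batch. The original DML paper instead recomputes the predictions after the first net's update before updating the second net, which is the extra forward pass this variant avoids. `dml_step` and the toy `nn.Linear` nets are hypothetical stand-ins, not this repository's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dml_step(net1, net2, opt1, opt2, x, y):
    """One synchronous DML update: a single forward pass per net, then both nets step."""
    out1, out2 = net1(x), net2(x)
    # Each net mimics its peer's prediction; the peer is detached so gradients stay separate.
    kl1 = F.kl_div(F.log_softmax(out1, dim=1), F.softmax(out2.detach(), dim=1),
                   reduction="batchmean")
    kl2 = F.kl_div(F.log_softmax(out2, dim=1), F.softmax(out1.detach(), dim=1),
                   reduction="batchmean")
    loss1 = F.cross_entropy(out1, y) + kl1
    loss2 = F.cross_entropy(out2, y) + kl2
    opt1.zero_grad()
    opt2.zero_grad()
    # loss1 only depends on net1 and loss2 only on net2, so one backward serves both.
    (loss1 + loss2).backward()
    opt1.step()
    opt2.step()
    return loss1.item(), loss2.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    net1, net2 = nn.Linear(4, 3), nn.Linear(4, 3)  # toy stand-ins for the two ResNets
    opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
    x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))
    print(dml_step(net1, net2, opt1, opt2, x, y))
```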

Datasets

  • CIFAR10
  • CIFAR100

Networks

  • Resnet-20
  • Resnet-110

The networks are the same as those in Table 6 of the paper.
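Assuming "the paper" is the original ResNet paper (He et al., Deep Residual Learning for Image Recognition), whose Table 6 covers CIFAR-10 ResNets of depth 6n+2 (a 3×3 stem conv, three stages of n basic blocks with 16, 32 and 64 filters, GAP and a fully connected layer), a compact sketch could look like this. n=3 gives ResNet-20 and n=18 gives ResNet-110. The stage names rb1 to rb3 only mirror the LwM note; the 1×1 projection shortcut is an assumption (the ResNet paper used zero-padded identity shortcuts on CIFAR), and this is not the repository's model code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two 3x3 convs with a residual connection."""

    def __init__(self, in_ch: int, out_ch: int, stride: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            # Assumption: 1x1 projection shortcut when the shape changes.
            self.shortcut = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                                          nn.BatchNorm2d(out_ch))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class CifarResNet(nn.Module):
    """ResNet-(6n+2) for 32x32 inputs: stem conv, three stages of n blocks, GAP, FC."""

    def __init__(self, depth: int = 20, num_classes: int = 10):
        super().__init__()
        assert (depth - 2) % 6 == 0, "depth must be 6n+2"
        n = (depth - 2) // 6
        self.stem = nn.Sequential(nn.Conv2d(3, 16, 3, 1, 1, bias=False),
                                  nn.BatchNorm2d(16), nn.ReLU(inplace=True))
        self.rb1 = self._make_stage(16, 16, n, stride=1)  # 32x32 feature maps
        self.rb2 = self._make_stage(16, 32, n, stride=2)  # 16x16
        self.rb3 = self._make_stage(32, 64, n, stride=2)  # 8x8
        self.fc = nn.Linear(64, num_classes)

    @staticmethod
    def _make_stage(in_ch: int, out_ch: int, n: int, stride: int) -> nn.Sequential:
        blocks = [BasicBlock(in_ch, out_ch, stride)]
        blocks += [BasicBlock(out_ch, out_ch, 1) for _ in range(n - 1)]
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.rb3(self.rb2(self.rb1(self.stem(x))))
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)  # global average pooling (GAP)
        return self.fc(out)
```

Counting the stem conv, the 2n convs per stage and the FC layer gives 1 + 6n + 1 = 6n + 2 weighted layers.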

Training

  • Create a ./dataset directory and download CIFAR10/CIFAR100 into it.
  • Use the script example_train_script.sh to train the various KD methods. You can simply specify the hyper-parameters listed in train_xxx.py or change them manually.
  • The hyper-parameters I used can be found in the training logs (code: ezed).
  • Some notes:
    • Sobolev/LwM alone is unstable and is better used in conjunction with other KD methods.
    • BSS may occasionally derail the training procedure, leading to poor results.
    • Unless the original papers specify otherwise, methods that could be applied to middle or multiple feature maps are employed only after the last conv layer. Extending them to multiple feature maps is simple.
    • I assume the feature sizes (C, H, W) of the teacher and student are the same. If not, you could employ a 1×1 conv, a linear layer or pooling to align them.
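When the shapes do differ, a minimal sketch of the alignment might combine a 1×1 convolution for the channel dimension with adaptive average pooling for the spatial size, followed by a Fitnet-style MSE. `FeatureAdapter` and `hint_loss` are hypothetical names, and the BatchNorm after the 1×1 conv is a common but assumed extra.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureAdapter(nn.Module):
    """Hypothetical helper mapping student features (N, Cs, Hs, Ws) to the teacher's shape."""

    def __init__(self, c_student: int, c_teacher: int):
        super().__init__()
        # 1x1 conv aligns the channel count; the BatchNorm is an optional, assumed extra.
        self.proj = nn.Sequential(
            nn.Conv2d(c_student, c_teacher, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_teacher),
        )

    def forward(self, fm_s: torch.Tensor, target_hw) -> torch.Tensor:
        out = self.proj(fm_s)
        target_hw = tuple(target_hw)
        if tuple(out.shape[-2:]) != target_hw:
            # Adaptive average pooling aligns H x W (intended for shrinking larger maps).
            out = F.adaptive_avg_pool2d(out, target_hw)
        return out


def hint_loss(fm_s: torch.Tensor, fm_t: torch.Tensor, adapter: FeatureAdapter) -> torch.Tensor:
    """Fitnet-style L2 hint loss after aligning the student features to the teacher's."""
    return F.mse_loss(adapter(fm_s, fm_t.shape[-2:]), fm_t.detach())


if __name__ == "__main__":
    fm_s = torch.randn(4, 16, 16, 16)  # student: 16 channels, 16x16
    fm_t = torch.randn(4, 64, 8, 8)    # teacher: 64 channels, 8x8
    adapter = FeatureAdapter(16, 64)
    loss = hint_loss(fm_s, fm_t, adapter)
    loss.backward()
    print(loss.item())
```

The adapter's parameters have to be added to the student's optimizer, otherwise the projection stays at its random initialization.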

Results

  • The trained baseline models are used as teachers. For a fair comparison, all student nets share the same initialization as the baseline models.
  • The initial models, trained models and training logs are uploaded here (code: ezed).
  • The trade-off parameter --lambda_kd and other hyper-parameters were not chosen carefully, so the following results do not indicate which method is better than the others.
  • Some relation-based methods, e.g. PKT, RKD and CC, are less effective on the CIFAR100 dataset. This may be because a batch contains more inter-class pairs but fewer intra-class pairs. You could increase the batch size, create a memory bank or design more advanced batch sampling methods.
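To make the batch-composition argument concrete, this pure-Python sketch counts the expected intra-class and inter-class sample pairs in one batch. It assumes a batch size of 128 (hypothetical here, not necessarily the repository's setting) and labels drawn independently and uniformly over the classes.

```python
def expected_pair_counts(batch_size: int, num_classes: int):
    """Expected numbers of intra-class and inter-class sample pairs in one batch.

    Assumes each label is drawn independently and uniformly over the classes, so two
    samples share a class with probability 1 / num_classes.
    """
    total_pairs = batch_size * (batch_size - 1) // 2
    intra = total_pairs / num_classes
    return intra, total_pairs - intra


for dataset, num_classes in (("CIFAR10", 10), ("CIFAR100", 100)):
    intra, inter = expected_pair_counts(128, num_classes)  # batch size 128 is an assumption
    print(f"{dataset}: {intra:.2f} intra-class vs {inter:.2f} inter-class pairs")
```

Under these assumptions, a CIFAR10 batch holds about 812.8 same-class pairs out of 8,128, while a CIFAR100 batch holds only about 81.28. Doubling the batch size to 256 raises the CIFAR100 figure to 326.4, since the pair count grows roughly quadratically with the batch size.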
Teacher Student Name CIFAR10 CIFAR100
- resnet-20 Baseline 92.37% 68.92%
resnet-20 resnet-20 Logits 93.30% 70.36%
resnet-20 resnet-20 ST 93.12% 70.27%
resnet-20 resnet-20 AT 92.89% 69.70%
resnet-20 resnet-20 Fitnet 92.73% 70.08%
resnet-20 resnet-20 NST 92.79% 69.21%
resnet-20 resnet-20 PKT 92.50% 69.25%
resnet-20 resnet-20 FSP 92.76% 69.61%
resnet-20 resnet-20 FT 92.98% 69.90%
resnet-20 resnet-20 RKD 92.72% 69.48%
resnet-20 resnet-20 AB 93.04% 69.96%
resnet-20 resnet-20 SP 92.88% 69.85%
resnet-20 resnet-20 Sobolev 92.78% 69.39%
resnet-20 resnet-20 BSS 92.58% 69.96%
resnet-20 resnet-20 CC 93.01% 69.27%
resnet-20 resnet-20 LwM 92.80% 69.23%
resnet-20 resnet-20 IRG 92.77% 70.37%
resnet-20 resnet-20 VID 92.61% 69.39%
resnet-20 resnet-20 OFD 92.82% 69.93%
resnet-20 resnet-20 AFD 92.56% 69.63%
resnet-20 resnet-20 CRD 92.96% 70.33%
Teacher Student Name CIFAR10 CIFAR100
- resnet-20 Baseline 92.37% 68.92%
- resnet-110 Baseline 93.86% 73.15%
resnet-110 resnet-20 Logits 92.98% 69.78%
resnet-110 resnet-20 ST 92.82% 70.06%
resnet-110 resnet-20 AT 93.21% 69.28%
resnet-110 resnet-20 Fitnet 93.04% 69.81%
resnet-110 resnet-20 NST 92.83% 69.31%
resnet-110 resnet-20 PKT 93.01% 69.31%
resnet-110 resnet-20 FSP 92.78% 69.78%
resnet-110 resnet-20 FT 93.01% 69.49%
resnet-110 resnet-20 RKD 93.21% 69.36%
resnet-110 resnet-20 AB 92.96% 69.41%
resnet-110 resnet-20 SP 93.30% 69.45%
resnet-110 resnet-20 Sobolev 92.60% 69.23%
resnet-110 resnet-20 BSS 92.78% 69.71%
resnet-110 resnet-20 CC 92.98% 69.33%
resnet-110 resnet-20 LwM 92.52% 69.11%
resnet-110 resnet-20 IRG 93.13% 69.36%
resnet-110 resnet-20 VID 92.98% 69.49%
resnet-110 resnet-20 OFD 93.13% 69.81%
resnet-110 resnet-20 AFD 92.92% 69.60%
resnet-110 resnet-20 CRD 92.92% 70.80%
Teacher Student Name CIFAR10 CIFAR100
- resnet-110 Baseline 93.86% 73.15%
resnet-110 resnet-110 Logits 94.38% 74.89%
resnet-110 resnet-110 ST 94.59% 74.33%
resnet-110 resnet-110 AT 94.42% 74.64%
resnet-110 resnet-110 Fitnet 94.43% 73.63%
resnet-110 resnet-110 NST 94.43% 73.55%
resnet-110 resnet-110 PKT 94.35% 73.74%
resnet-110 resnet-110 FSP 94.39% 73.59%
resnet-110 resnet-110 FT 94.30% 74.72%
resnet-110 resnet-110 RKD 94.39% 73.78%
resnet-110 resnet-110 AB 94.63% 73.91%
resnet-110 resnet-110 SP 94.45% 74.07%
resnet-110 resnet-110 Sobolev 94.26% 73.14%
resnet-110 resnet-110 BSS 94.19% 73.87%
resnet-110 resnet-110 CC 94.49% 74.43%
resnet-110 resnet-110 LwM 94.19% 73.28%
resnet-110 resnet-110 IRG 94.44% 74.96%
resnet-110 resnet-110 VID 94.25% 73.63%
resnet-110 resnet-110 OFD 94.38% 74.11%
resnet-110 resnet-110 AFD 94.44% 73.90%
resnet-110 resnet-110 CRD 94.30% 75.44%
Net1 Net2 Name CIFAR10 CIFAR100
- resnet-20 Baseline 92.37% 68.92%
- resnet-110 Baseline 93.86% 73.15%
resnet-20 resnet-20 DML 93.07%/93.37% 70.39%/70.22%
resnet-110 resnet-20 DML 94.45%/92.92% 74.53%/70.29%
resnet-110 resnet-110 DML 94.74%/94.79% 74.72%/75.55%

Todo List

  • KDSVD (currently has some bugs)
  • QuEST: Quantized Embedding Space for Transferring Knowledge
  • EEL: Learning an Evolutionary Embedding via Massive Knowledge Distillation
  • OnAdvFD: Feature-map-level Online Adversarial Knowledge Distillation
  • CS-KD: Regularizing Class-wise Predictions via Self-knowledge Distillation
  • PAD: Prime-Aware Adaptive Distillation
  • CD: Channel Distillation: Channel-Wise Attention for Knowledge Distillation
  • DCM: Knowledge Transfer via Dense Cross-Layer Mutual-Distillation

Requirements

  • python 3.7
  • pytorch 1.3.1
  • torchvision 0.4.2

Acknowledgements

This repo is partly based on the following repos; many thanks to their authors.

If you employ the listed KD methods in your research, please cite the corresponding papers.