PyTorch implementation of various Knowledge Distillation (KD) methods.

This repository is a simple reference that mainly focuses on basic knowledge distillation/transfer methods. Many tricks and variations, such as step-by-step training, iterative training, ensembles of teachers, ensembles of KD methods, data-free distillation, self-distillation and quantization, are therefore not considered. I hope it is useful for your project or research.

I will update this repo regularly with new KD methods. If I have missed some basic methods, please contact me.
Name | Method | Paper Link | Code Link |
---|---|---|---|
Baseline | basic model with softmax loss | — | code |
Logits | mimic learning via regressing logits | paper | code |
ST | soft target | paper | code |
AT | attention transfer | paper | code |
Fitnet | hints for thin deep nets | paper | code |
NST | neuron selectivity transfer | paper | code |
PKT | probabilistic knowledge transfer | paper | code |
FSP | flow of solution procedure | paper | code |
FT | factor transfer | paper | code |
RKD | relational knowledge distillation | paper | code |
AB | activation boundary | paper | code |
SP | similarity preservation | paper | code |
Sobolev | Sobolev/Jacobian matching | paper | code |
BSS | boundary supporting samples | paper | code |
CC | correlation congruence | paper | code |
LwM | learning without memorizing | paper | code |
IRG | instance relationship graph | paper | code |
VID | variational information distillation | paper | code |
OFD | overhaul of feature distillation | paper | code |
AFD | attention feature distillation | paper | code |
CRD | contrastive representation distillation | paper | code |
DML | deep mutual learning | paper | code |
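
The two logit-based entries in the table (Logits and ST) are simple enough to sketch directly. The snippet below is a minimal illustration assuming the standard formulations from the cited papers: mean squared error between student and teacher logits for Logits, and a temperature-scaled KL divergence multiplied by T² for ST. The function names and the default temperature `T=4.0` are illustrative assumptions rather than this repository's code.

```python
import torch
import torch.nn.functional as F


def logits_loss(out_s: torch.Tensor, out_t: torch.Tensor) -> torch.Tensor:
    """Logits (mimic learning): regress the student logits onto the teacher logits."""
    return F.mse_loss(out_s, out_t)


def soft_target_loss(out_s: torch.Tensor, out_t: torch.Tensor, T: float = 4.0) -> torch.Tensor:
    """ST (soft targets): KL(teacher || student) on temperature-softened distributions.

    The T * T factor (an assumption following Hinton et al.) keeps gradient
    magnitudes roughly comparable across temperatures.
    """
    log_p_s = F.log_softmax(out_s / T, dim=1)  # student log-probabilities
    p_t = F.softmax(out_t / T, dim=1)          # teacher probabilities (targets)
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * T * T


if __name__ == "__main__":
    torch.manual_seed(0)
    out_t = torch.randn(8, 10)                      # teacher logits: batch of 8, 10 classes
    out_s = torch.randn(8, 10, requires_grad=True)  # student logits
    print(logits_loss(out_s, out_t).item(), soft_target_loss(out_s, out_t).item())
```

In practice the teacher logits would be computed under `torch.no_grad()` (or detached) so that only the student receives gradients.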
- Note: there are some differences between this repository and the original papers:
  - For AT: I use the sum of absolute values raised to the power p=2 as the attention map.
  - For Fitnet: the training procedure is one-stage, without a hint layer.
  - For NST: I employ a polynomial kernel with d=2 and c=0.
  - For AB: two-stage training; the first 50 epochs are used for initialization, and the second stage employs only CE, without ST.
  - For BSS: 80% of the epochs employ the CE+BSS loss and the remaining 20% use only CE. In addition, there is a warmup during the first 10 epochs.
  - For CC: for consistency, I only consider CC without instance congruence. A Gaussian RBF kernel is employed because the bilinear pool kernel is similar to PKT. I use the P=2 order Taylor expansion of the Gaussian RBF kernel. No special sampling strategy is used.
  - For LwM: I employ it after rb2 (the middle conv layer) rather than rb3 (the last conv layer), because the base net is a ResNet that ends with GAP followed by a classifier. If it were applied after rb3, the Grad-CAM would have the same value across H and W in each channel.
  - For IRG: I only use the one-to-one mode.
  - For VID: I set the hidden channel size to be the same as the output channel size and remove the BN in μ.
  - For AFD: I found the original implementation of attention to be unstable, so I replaced it with an SE block.
  - For DML: just two nets are employed. They are updated synchronously to avoid multiple forward passes.
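
The AT and NST settings above can be illustrated with the following sketch. It assumes NCHW feature maps whose spatial sizes (H, W) match between teacher and student (channel counts may differ). `attention_map`, `at_loss`, `poly_kernel` and `nst_loss` are hypothetical names, and the normalization details are assumptions that may not match this repository exactly. NST is written here as the biased squared MMD between the sets of L2-normalized channel maps, using the polynomial kernel k(x, y) = (xᵀy + c)^d with d=2 and c=0.

```python
import torch
import torch.nn.functional as F


def attention_map(fm: torch.Tensor, p: int = 2, eps: float = 1e-6) -> torch.Tensor:
    """AT attention: sum over channels of |A|^p, L2-normalised over H*W. Returns (B, H*W)."""
    am = fm.abs().pow(p).sum(dim=1).flatten(1)
    return am / (am.norm(dim=1, keepdim=True) + eps)


def at_loss(fm_s: torch.Tensor, fm_t: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Match the normalised student and teacher attention maps (requires equal H, W)."""
    return F.mse_loss(attention_map(fm_s, p), attention_map(fm_t, p))


def poly_kernel(x: torch.Tensor, y: torch.Tensor, d: int = 2, c: float = 0.0) -> torch.Tensor:
    """Polynomial kernel (x . y + c)^d between all row pairs: (B, N1, D) x (B, N2, D) -> (B, N1, N2)."""
    return (torch.bmm(x, y.transpose(1, 2)) + c).pow(d)


def nst_loss(fm_s: torch.Tensor, fm_t: torch.Tensor, d: int = 2, c: float = 0.0) -> torch.Tensor:
    """NST: squared MMD between student and teacher channel activations.

    Each channel's flattened spatial map (L2-normalised) is treated as one sample.
    """
    s = F.normalize(fm_s.flatten(2), dim=2)  # (B, C_s, H*W)
    t = F.normalize(fm_t.flatten(2), dim=2)  # (B, C_t, H*W)
    return (poly_kernel(s, s, d, c).mean()
            + poly_kernel(t, t, d, c).mean()
            - 2 * poly_kernel(s, t, d, c).mean())


if __name__ == "__main__":
    torch.manual_seed(0)
    fm_s = torch.randn(4, 16, 8, 8)  # student features
    fm_t = torch.randn(4, 32, 8, 8)  # teacher features: more channels, same H and W
    print(at_loss(fm_s, fm_t).item(), nst_loss(fm_s, fm_t).item())
```

Both losses reduce each feature map to quantities that do not depend on the channel count, which is why the sketch tolerates different C but not different H and W.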
Datasets:
- CIFAR10
- CIFAR100

Networks:
- ResNet-20
- ResNet-110

The networks are the same as those in Table 6 of the original ResNet paper.
Training:
- Create a `./dataset` directory and download CIFAR10/CIFAR100 into it.
- Use the script `example_train_script.sh` to train the various KD methods. You can simply specify the hyper-parameters listed in `train_xxx.py` or change them manually.
- The hyper-parameters I used can be found in the training logs (code: ezed).
- Some notes:
  - Sobolev/LwM alone is unstable and may be used in conjunction with other KD methods.
  - BSS may occasionally destroy the training procedure, leading to poor results.
  - Unless specified otherwise in the original papers, methods that can be applied to intermediate or multiple feature maps are only employed after the last conv layer. Extending them to multiple feature maps is simple.
  - I assume that the teacher and student features have the same size (C, H, W). If not, you could employ a 1x1 conv, a linear layer or pooling to match them.
  - The trained baseline models are used as teachers. For a fair comparison, all student nets share the same initialization as the baseline models.
  - The initial models, trained models and training logs are uploaded here (code: ezed).
  - The trade-off parameter `--lambda_kd` and other hyper-parameters were not chosen carefully, so the following results do not indicate which method is better than the others.
  - Some relation-based methods, e.g. PKT, RKD and CC, are less effective on CIFAR100. This may be because each batch contains many different classes but few samples per class, so there are few intra-class relations to transfer. You could increase the batch size, create a memory bank or design more advanced batch sampling methods.
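
The note on mismatched feature sizes, together with the `--lambda_kd` trade-off, can be sketched as below. `FeatureAdapter` and `total_loss` are hypothetical helpers, not part of this repository: a 1x1 conv maps the student channels to the teacher's, adaptive average pooling matches H and W, and MSE stands in for whichever feature-based KD loss is used. The sketch assumes that `lambda_kd` weights the KD term against the cross-entropy loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureAdapter(nn.Module):
    """Hypothetical helper: map student features (B, C_s, H_s, W_s) to (B, C_t, H_t, W_t)."""

    def __init__(self, c_s: int, c_t: int, out_hw: tuple):
        super().__init__()
        self.conv = nn.Conv2d(c_s, c_t, kernel_size=1, bias=False)  # fixes the channel count
        self.out_hw = out_hw

    def forward(self, fm_s: torch.Tensor) -> torch.Tensor:
        # Adaptive average pooling fixes the spatial size.
        return F.adaptive_avg_pool2d(self.conv(fm_s), self.out_hw)


def total_loss(logits_s, targets, fm_s, fm_t, adapter, lambda_kd: float = 1.0):
    """Cross-entropy on the labels plus lambda_kd times a feature loss (MSE as a stand-in)."""
    ce = F.cross_entropy(logits_s, targets)
    kd = F.mse_loss(adapter(fm_s), fm_t.detach())  # teacher features act as fixed targets
    return ce + lambda_kd * kd


if __name__ == "__main__":
    torch.manual_seed(0)
    adapter = FeatureAdapter(c_s=16, c_t=64, out_hw=(8, 8))
    fm_s, fm_t = torch.randn(4, 16, 16, 16), torch.randn(4, 64, 8, 8)
    logits, y = torch.randn(4, 10), torch.tensor([0, 1, 2, 3])
    print(total_loss(logits, y, fm_s, fm_t, adapter, lambda_kd=1.0).item())
```

Because the adapter has its own parameters, it would need to be passed to the student's optimizer as well.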

Teacher | Student | Name | CIFAR10 | CIFAR100 |
---|---|---|---|---|
- | resnet-20 | Baseline | 92.37% | 68.92% |
resnet-20 | resnet-20 | Logits | 93.30% | 70.36% |
resnet-20 | resnet-20 | ST | 93.12% | 70.27% |
resnet-20 | resnet-20 | AT | 92.89% | 69.70% |
resnet-20 | resnet-20 | Fitnet | 92.73% | 70.08% |
resnet-20 | resnet-20 | NST | 92.79% | 69.21% |
resnet-20 | resnet-20 | PKT | 92.50% | 69.25% |
resnet-20 | resnet-20 | FSP | 92.76% | 69.61% |
resnet-20 | resnet-20 | FT | 92.98% | 69.90% |
resnet-20 | resnet-20 | RKD | 92.72% | 69.48% |
resnet-20 | resnet-20 | AB | 93.04% | 69.96% |
resnet-20 | resnet-20 | SP | 92.88% | 69.85% |
resnet-20 | resnet-20 | Sobolev | 92.78% | 69.39% |
resnet-20 | resnet-20 | BSS | 92.58% | 69.96% |
resnet-20 | resnet-20 | CC | 93.01% | 69.27% |
resnet-20 | resnet-20 | LwM | 92.80% | 69.23% |
resnet-20 | resnet-20 | IRG | 92.77% | 70.37% |
resnet-20 | resnet-20 | VID | 92.61% | 69.39% |
resnet-20 | resnet-20 | OFD | 92.82% | 69.93% |
resnet-20 | resnet-20 | AFD | 92.56% | 69.63% |
resnet-20 | resnet-20 | CRD | 92.96% | 70.33% |

Teacher | Student | Name | CIFAR10 | CIFAR100 |
---|---|---|---|---|
- | resnet-20 | Baseline | 92.37% | 68.92% |
- | resnet-110 | Baseline | 93.86% | 73.15% |
resnet-110 | resnet-20 | Logits | 92.98% | 69.78% |
resnet-110 | resnet-20 | ST | 92.82% | 70.06% |
resnet-110 | resnet-20 | AT | 93.21% | 69.28% |
resnet-110 | resnet-20 | Fitnet | 93.04% | 69.81% |
resnet-110 | resnet-20 | NST | 92.83% | 69.31% |
resnet-110 | resnet-20 | PKT | 93.01% | 69.31% |
resnet-110 | resnet-20 | FSP | 92.78% | 69.78% |
resnet-110 | resnet-20 | FT | 93.01% | 69.49% |
resnet-110 | resnet-20 | RKD | 93.21% | 69.36% |
resnet-110 | resnet-20 | AB | 92.96% | 69.41% |
resnet-110 | resnet-20 | SP | 93.30% | 69.45% |
resnet-110 | resnet-20 | Sobolev | 92.60% | 69.23% |
resnet-110 | resnet-20 | BSS | 92.78% | 69.71% |
resnet-110 | resnet-20 | CC | 92.98% | 69.33% |
resnet-110 | resnet-20 | LwM | 92.52% | 69.11% |
resnet-110 | resnet-20 | IRG | 93.13% | 69.36% |
resnet-110 | resnet-20 | VID | 92.98% | 69.49% |
resnet-110 | resnet-20 | OFD | 93.13% | 69.81% |
resnet-110 | resnet-20 | AFD | 92.92% | 69.60% |
resnet-110 | resnet-20 | CRD | 92.92% | 70.80% |

Teacher | Student | Name | CIFAR10 | CIFAR100 |
---|---|---|---|---|
- | resnet-110 | Baseline | 93.86% | 73.15% |
resnet-110 | resnet-110 | Logits | 94.38% | 74.89% |
resnet-110 | resnet-110 | ST | 94.59% | 74.33% |
resnet-110 | resnet-110 | AT | 94.42% | 74.64% |
resnet-110 | resnet-110 | Fitnet | 94.43% | 73.63% |
resnet-110 | resnet-110 | NST | 94.43% | 73.55% |
resnet-110 | resnet-110 | PKT | 94.35% | 73.74% |
resnet-110 | resnet-110 | FSP | 94.39% | 73.59% |
resnet-110 | resnet-110 | FT | 94.30% | 74.72% |
resnet-110 | resnet-110 | RKD | 94.39% | 73.78% |
resnet-110 | resnet-110 | AB | 94.63% | 73.91% |
resnet-110 | resnet-110 | SP | 94.45% | 74.07% |
resnet-110 | resnet-110 | Sobolev | 94.26% | 73.14% |
resnet-110 | resnet-110 | BSS | 94.19% | 73.87% |
resnet-110 | resnet-110 | CC | 94.49% | 74.43% |
resnet-110 | resnet-110 | LwM | 94.19% | 73.28% |
resnet-110 | resnet-110 | IRG | 94.44% | 74.96% |
resnet-110 | resnet-110 | VID | 94.25% | 73.63% |
resnet-110 | resnet-110 | OFD | 94.38% | 74.11% |
resnet-110 | resnet-110 | AFD | 94.44% | 73.90% |
resnet-110 | resnet-110 | CRD | 94.30% | 75.44% |

Net1 | Net2 | Name | CIFAR10 (Net1/Net2) | CIFAR100 (Net1/Net2) |
---|---|---|---|---|
- | resnet-20 | Baseline | 92.37% | 68.92% |
- | resnet-110 | Baseline | 93.86% | 73.15% |
resnet-20 | resnet-20 | DML | 93.07%/93.37% | 70.39%/70.22% |
resnet-110 | resnet-20 | DML | 94.45%/92.92% | 74.53%/70.29% |
resnet-110 | resnet-110 | DML | 94.74%/94.79% | 74.72%/75.55% |
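
The DML setting (two nets, updated synchronously) can be sketched as follows: both peers are run once on the same batch, each receives cross-entropy plus a KL divergence towards the other's detached prediction, and both optimizers step together. This is a minimal illustration under those assumptions; `dml_losses` and `dml_step` are hypothetical names, `nn.Linear` stands in for the ResNets, and the repository's implementation may differ in details such as temperature or loss weighting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dml_losses(logits1, logits2, targets):
    """Deep mutual learning losses for two peers computed from a single forward pass.

    Each net gets CE on the labels plus KL(peer || self); the peer prediction is
    detached so that it acts as a fixed target for that term.
    """
    kl_1 = F.kl_div(F.log_softmax(logits1, dim=1), F.softmax(logits2.detach(), dim=1),
                    reduction="batchmean")
    kl_2 = F.kl_div(F.log_softmax(logits2, dim=1), F.softmax(logits1.detach(), dim=1),
                    reduction="batchmean")
    loss1 = F.cross_entropy(logits1, targets) + kl_1
    loss2 = F.cross_entropy(logits2, targets) + kl_2
    return loss1, loss2


def dml_step(net1, net2, opt1, opt2, x, targets):
    """One synchronous update: each net runs one forward pass, then both are stepped."""
    logits1, logits2 = net1(x), net2(x)
    loss1, loss2 = dml_losses(logits1, logits2, targets)
    opt1.zero_grad()
    opt2.zero_grad()
    # loss1 only depends on net1 and loss2 only on net2 (peer outputs are detached),
    # so a single backward on the sum gives each net its own gradient.
    (loss1 + loss2).backward()
    opt1.step()
    opt2.step()
    return loss1.item(), loss2.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    net1, net2 = nn.Linear(32, 10), nn.Linear(32, 10)  # stand-ins for two ResNets
    opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
    x, y = torch.randn(16, 32), torch.randint(0, 10, (16,))
    print(dml_step(net1, net2, opt1, opt2, x, y))
```

Compared with updating the two nets alternately, this avoids recomputing the first net's predictions after its update, at the cost of each net seeing a peer prediction that is one step older.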

TODO:
- KDSVD (currently has some bugs)
- QuEST: Quantized Embedding Space for Transferring Knowledge
- EEL: Learning an Evolutionary Embedding via Massive Knowledge Distillation
- OnAdvFD: Feature-map-level Online Adversarial Knowledge Distillation
- CS-KD: Regularizing Class-wise Predictions via Self-knowledge Distillation
- PAD: Prime-Aware Adaptive Distillation
- CD: Channel Distillation: Channel-Wise Attention for Knowledge Distillation
- DCM: Knowledge Transfer via Dense Cross-Layer Mutual-Distillation

Requirements:
- Python 3.7
- PyTorch 1.3.1
- torchvision 0.4.2

This repo is partly based on the following repos; many thanks to their authors.
- HobbitLong/RepDistiller
- bhheo/BSS_distillation
- clovaai/overhaul-distillation
- passalis/probabilistic_kt
- lenscloth/RKD
If you employ the listed KD methods in your research, please cite the corresponding papers.