mdistiller

A Knowledge Distillation Toolbox. The official implementation of https://arxiv.org/abs/2203.08679


This repo is

(1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks,

(2) the official implementation of the CVPR-2022 paper: Decoupled Knowledge Distillation.

Decoupled Knowledge Distillation

Framework & Performance

Main Benchmark Results

On CIFAR-100:

| Teacher | Student | KD | DKD |
|---|---|---|---|
| ResNet56 | ResNet20 | 70.66 | 71.97 |
| ResNet110 | ResNet32 | 73.08 | 74.11 |
| ResNet32x4 | ResNet8x4 | 73.33 | 76.32 |
| WRN-40-2 | WRN-16-2 | 74.92 | 76.23 |
| WRN-40-2 | WRN-40-1 | 73.54 | 74.81 |
| VGG13 | VGG8 | 72.98 | 74.68 |

| Teacher | Student | KD | DKD |
|---|---|---|---|
| ResNet32x4 | ShuffleNet-V1 | 74.07 | 76.45 |
| WRN-40-2 | ShuffleNet-V1 | 74.83 | 76.70 |
| VGG13 | MobileNet-V2 | 67.37 | 69.71 |
| ResNet50 | MobileNet-V2 | 67.35 | 70.35 |
| ResNet32x4 | MobileNet-V2 | 74.45 | 77.07 |

On ImageNet:

| Teacher | Student | KD | DKD |
|---|---|---|---|
| ResNet34 | ResNet18 | 71.03 | 71.70 |
| ResNet50 | MobileNet-V2 | 70.50 | 72.05 |

MDistiller

Introduction

MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:

| Method | Paper Link |
|---|---|
| KD | https://arxiv.org/abs/1503.02531 |
| FitNet | https://arxiv.org/abs/1412.6550 |
| AT | https://arxiv.org/abs/1612.03928 |
| NST | https://arxiv.org/abs/1707.01219 |
| PKT | https://arxiv.org/abs/1803.10837 |
| KDSVD | https://arxiv.org/abs/1807.06819 |
| OFD | https://arxiv.org/abs/1904.01866 |
| RKD | https://arxiv.org/abs/1904.05068 |
| VID | https://arxiv.org/abs/1904.05835 |
| SP | https://arxiv.org/abs/1907.09682 |
| CRD | https://arxiv.org/abs/1910.10699 |
| ReviewKD | https://arxiv.org/abs/2104.09044 |
| DKD | https://arxiv.org/abs/2203.08679 |

Installation

Environments:

  • Python 3.6
  • PyTorch 1.9.0
  • torchvision 0.10.0

Install the package:

sudo pip3 install -r requirements.txt
sudo python3 setup.py develop

Getting started

  1. Wandb as the logger
  • Registration: https://wandb.ai/home.
  • If you don't want to use wandb as your logger, set CFG.LOG.WANDB to False in mdistiller/engine/cfg.py.
  2. Evaluation
  • You can evaluate the performance of our models or of models you trained yourself.

  • Our models are at https://github.com/megvii-research/mdistiller/releases/tag/checkpoints. Download the checkpoints to ./download_ckpts.

  • To test the models on ImageNet, download the dataset from https://image-net.org/ and put it in ./data/imagenet.

    # evaluate teachers
    python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
    python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet
    
    # evaluate students
    python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
    python3 tools/eval.py -m MobileNetV2 -c download_ckpts/imgnet_dkd_mv2 -d imagenet # dkd-mv2 on imagenet
    python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
  3. Training on CIFAR-100
  • Download cifar_teachers.tar from https://github.com/megvii-research/mdistiller/releases/tag/checkpoints and extract it to ./download_ckpts with tar xvf cifar_teachers.tar.

    # for instance, our DKD method.
    python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml
    
    # you can also change settings at command line
    python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
  4. Training on ImageNet
  • Download the dataset from https://image-net.org/ and put it in ./data/imagenet.

    # for instance, our DKD method.
    python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
  5. Training on MS-COCO
  6. Extension: Visualizations
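
The command-line overrides shown in the CIFAR-100 training step are alternating KEY VALUE pairs with dotted keys (for example SOLVER.BATCH_SIZE 128). The sketch below illustrates how such pairs could be merged into a nested configuration. It is only an illustration built on plain dicts: the `apply_overrides` helper and the default values are hypothetical, and the library's own config handling in mdistiller/engine/cfg.py may work differently.

```python
import ast
from copy import deepcopy


def apply_overrides(cfg, opts):
    """Return a copy of the nested dict `cfg` with dotted-key overrides applied.

    `opts` is a flat list such as ["SOLVER.LR", "0.1", "LOG.WANDB", "False"].
    (Hypothetical helper, not part of mdistiller.)
    """
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    merged = deepcopy(cfg)
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = merged
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        try:
            value = ast.literal_eval(raw)  # "128" -> 128, "False" -> False
        except (ValueError, SyntaxError):
            value = raw  # anything that is not a Python literal stays a string
        node[leaf] = value
    return merged


# Hypothetical defaults, for illustration only.
defaults = {
    "SOLVER": {"BATCH_SIZE": 64, "LR": 0.05},
    "LOG": {"WANDB": True},
}

cfg = apply_overrides(
    defaults,
    ["SOLVER.BATCH_SIZE", "128", "SOLVER.LR", "0.1", "LOG.WANDB", "False"],
)
print(cfg["SOLVER"]["BATCH_SIZE"], cfg["SOLVER"]["LR"], cfg["LOG"]["WANDB"])
# prints: 128 0.1 False
```

Rejecting unknown keys is a deliberate choice in this sketch: a misspelled key fails loudly instead of silently creating a new setting.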

Custom Distillation Method

  1. Create a Python file in mdistiller/distillers/ and define the distiller:

    from ._base import Distiller


    class MyDistiller(Distiller):
        def __init__(self, student, teacher, cfg):
            super(MyDistiller, self).__init__(student, teacher)
            self.hyper1 = cfg.MyDistiller.hyper1
            ...

        def forward_train(self, image, target, **kwargs):
            # return the output logits and a dict of losses
            ...

        # Override get_learnable_parameters if the distiller adds more nn modules.
        # Override get_extra_parameters if you want to obtain the extra cost.
        ...

  2. Register the distiller in distiller_dict in mdistiller/distillers/__init__.py.

  3. Register the corresponding hyper-parameters in mdistiller/engine/cfg.py.

  4. Create a new config file and test it.
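
The steps above can be sketched end to end in a dependency-free toy. Everything here is a stand-in chosen for illustration: the `Distiller` base class, the `distiller_dict` registry, the `DISTILLER.TYPE` key, and the hyper-parameter names `temperature` and `kd_weight` are assumptions, the "models" are plain functions returning lists of logits for a single sample, and the config is a plain dict rather than the attribute-style config used in the template. The loss is the classic temperature-scaled KD loss, picked for brevity; it is not DKD. In the real library the models are PyTorch modules and the losses are tensors.

```python
import math


class Distiller:
    """Stand-in for the base class in mdistiller/distillers/_base.py."""

    def __init__(self, student, teacher):
        self.student = student
        self.teacher = teacher


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over a list of floats."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in scaled))
    return [z - log_norm for z in scaled]


class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super().__init__(student, teacher)
        # Step 3: hyper-parameters are read from the config (hypothetical names).
        self.temperature = cfg["MyDistiller"]["temperature"]
        self.kd_weight = cfg["MyDistiller"]["kd_weight"]

    def forward_train(self, image, target, **kwargs):
        logits_student = self.student(image)
        logits_teacher = self.teacher(image)
        # Cross-entropy against the ground-truth class index.
        loss_ce = -log_softmax(logits_student)[target]
        # KL(teacher || student) on temperature-softened distributions.
        log_p_s = log_softmax(logits_student, self.temperature)
        log_p_t = log_softmax(logits_teacher, self.temperature)
        kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
        loss_kd = self.kd_weight * self.temperature ** 2 * kl
        # Return the output logits and a dict of losses.
        return logits_student, {"loss_ce": loss_ce, "loss_kd": loss_kd}


# Step 2: register the distiller under a name (stand-in registry).
distiller_dict = {"MyDistiller": MyDistiller}

# Step 4: a config selects the distiller and sets its hyper-parameters.
cfg = {
    "DISTILLER": {"TYPE": "MyDistiller"},
    "MyDistiller": {"temperature": 4.0, "kd_weight": 0.9},
}


def student(image):
    return [0.2 * x for x in image]  # toy "weak" model


def teacher(image):
    return [1.0 * x for x in image]  # toy "strong" model


distiller = distiller_dict[cfg["DISTILLER"]["TYPE"]](student, teacher, cfg)
logits, losses = distiller.forward_train(image=[1.0, 2.0, 3.0], target=2)
print(sorted(losses))
# prints: ['loss_ce', 'loss_kd']
```

Returning the losses as a named dict lets the training loop sum or log each term separately without knowing which distiller produced it.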

Citation

If this repo is helpful for your research, please consider citing the paper:

@article{zhao2022dkd,
  title={Decoupled Knowledge Distillation},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
  journal={arXiv preprint arXiv:2203.08679},
  year={2022}
}

License

MDistiller is released under the MIT license. See LICENSE for details.

Acknowledgement

  • Thanks to CRD and ReviewKD. We built this library on the CRD and ReviewKD codebases.

  • Thanks to Yiyu Qiu and Yi Shi for their code contributions during their internships at MEGVII Technology.

  • Thanks to Xin Jin for the discussion about DKD.