This repository contains the code and data for our ICML 2021 paper, "Zero-shot Knowledge Distillation from a Decision-based Black-box Model" (ZSDB3KD).
https://icml.cc/virtual/2021/poster/10257
https://arxiv.org/abs/2106.03310
Requirements:
NumPy
PyTorch (tested on 1.9.1)
Torchvision
.
├── data # datasets downloaded or saved here
├── generated_samples # generated pseudo samples saved here
├── labels # generated soft labels saved here
├── models # trained teacher models saved here
├── train_model_ce.py # standard training (a teacher) with cross-entropy loss
├── models.py # all model definitions and wrapper for sample robustness calculation
├── get_soft_labels.py # calculate soft labels with sample robustness
├── sample_robustness.py # methods for calculating sample robustness
├── train_model_kd.py # training with KD
├── get_pseudo_samples.py # generate pseudo samples with ZSDB3KD
├── untargeted_mbd.py # calculate the untargeted distances from a noise input to boundary
└── README.MD # readme file
Train a LeNet5 teacher with the MNIST dataset:
python train_model_ce.py --mode teacher --dataset MNIST --architecture LeNet5
Train a LeNet5 teacher with the FashionMNIST dataset:
python train_model_ce.py --mode teacher --dataset FASHIONMNIST --architecture LeNet5
Train an AlexNet teacher with the CIFAR10 dataset:
python train_model_ce.py --mode teacher --dataset CIFAR10 --architecture AlexNet
PS: train_model_ce.py can also train and evaluate the student models (e.g., LeNet5-half, LeNet5-fifth) with the cross-entropy loss alone.
Compute soft labels with sample robustness. Set --sr_mode to one of sd (sample distance), bd (boundary distance), or mbd (minimal boundary distance).
LeNet-5 with MNIST:
python get_soft_labels.py --dataset MNIST --sr_mode {sd/bd/mbd} --model_dir ./models/teacher_LeNet5_MNIST
LeNet-5 with FashionMNIST:
python get_soft_labels.py --dataset FASHIONMNIST --sr_mode {sd/bd/mbd} --model_dir ./models/teacher_LeNet5_FASHIONMNIST
AlexNet with CIFAR10:
python get_soft_labels.py --dataset CIFAR10 --sr_mode {sd/bd/mbd} --model_dir ./models/teacher_AlexNet_CIFAR10
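To make the idea of a boundary distance concrete, here is a minimal sketch, not the code in sample_robustness.py. It assumes a black box that returns only a top-1 label and binary-searches the straight segment between a sample and a point of a different class for the place where the label flips. The names `toy_teacher` and `boundary_distance`, and the 2-D toy classifier, are hypothetical and chosen only for illustration.

```python
import math


def toy_teacher(x):
    """Hypothetical hard-label black box: returns only a class index."""
    return 0 if x[0] + x[1] < 1.0 else 1


def boundary_distance(predict, x, x_other, tol=1e-6):
    """Binary-search the segment from x to x_other for the label flip.

    Returns the L2 distance from x to the approximate boundary point.
    Only top-1 labels from `predict` are queried (decision-based setting).
    """
    label = predict(x)
    if predict(x_other) == label:
        raise ValueError("x_other must be classified differently from x")

    seg_len = math.dist(x, x_other)
    lo, hi = 0.0, 1.0  # interpolation weights: lo keeps the label, hi flips it
    while (hi - lo) * seg_len > tol:
        mid = (lo + hi) / 2.0
        point = [a + mid * (b - a) for a, b in zip(x, x_other)]
        if predict(point) == label:
            lo = mid
        else:
            hi = mid
    return hi * seg_len


x = [0.0, 0.0]        # classified as class 0 by the toy teacher
x_other = [2.0, 2.0]  # classified as class 1
d = boundary_distance(toy_teacher, x, x_other)
print(round(d, 4))
```

In this toy setup the boundary is the line x0 + x1 = 1, so the search converges to the point (0.5, 0.5). A smaller distance toward some class indicates that the sample is less robust against that class, which is the kind of signal the soft labels are built from.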
Train a student with KD, passing the path of the saved soft labels (logits):
python train_model_kd.py --dataset {MNIST/FASHIONMNIST/CIFAR10} --mode {small/tiny} --logits PATH_OF_SAVED_LOGITS
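As a rough illustration of training against soft labels, the sketch below computes a cross-entropy between a teacher's soft label and a student's softened prediction. It is an assumption that train_model_kd.py uses a loss of this general form; the exact loss, any temperature, and any weighting in the repository may differ. The function names here are hypothetical, and plain Python is used instead of PyTorch only to keep the example self-contained.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(student_logits, teacher_probs, temperature=1.0):
    """Cross-entropy of the student's softened output against soft targets."""
    student_probs = softmax(student_logits, temperature)
    return -sum(
        t * math.log(s) for t, s in zip(teacher_probs, student_probs) if t > 0
    )


teacher_probs = [0.7, 0.2, 0.1]  # illustrative soft label, not real data

# A student whose prediction equals the soft label reaches the minimum loss.
matched = soft_cross_entropy([math.log(p) for p in teacher_probs], teacher_probs)

# A student that is confident in the wrong class is penalized more.
mismatched = soft_cross_entropy([0.0, 3.0, 0.0], teacher_probs)

print(matched < mismatched)
```

The loss is smallest when the student reproduces the teacher's soft label, where it equals the entropy of that label, which is why soft labels carry more information to the student than a single hard class index.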
Generate pseudo samples with ZSDB3KD (MNIST example):
python get_pseudo_samples.py --dataset MNIST --batch_size 200 --model_dir PATH_OF_SAVED_TEACHER_MODEL
To test ZSDB3KD, feed the generated pseudo samples to the soft-label step (get_soft_labels.py) and then to the KD training step (train_model_kd.py).
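The teacher, soft-label, and KD commands listed in this README can be assembled for one dataset with a small helper. This is a convenience sketch only: it uses just the flags and the ./models/teacher_{architecture}_{dataset} naming that appear in the commands in this README, `pipeline_commands` is a hypothetical name, and the logits path is left as the same placeholder because its real location is not stated. The helper prints the commands and does not run them.

```python
import shlex

# Teacher architecture per dataset, as used in the commands in this README.
ARCHITECTURES = {"MNIST": "LeNet5", "FASHIONMNIST": "LeNet5", "CIFAR10": "AlexNet"}
SR_MODES = ("sd", "bd", "mbd")
STUDENT_MODES = ("small", "tiny")


def pipeline_commands(dataset, sr_mode, student_mode, logits_path):
    """Return the teacher, soft-label, and KD command lines for one dataset."""
    if dataset not in ARCHITECTURES:
        raise ValueError(f"unknown dataset: {dataset}")
    if sr_mode not in SR_MODES:
        raise ValueError(f"unknown sr_mode: {sr_mode}")
    if student_mode not in STUDENT_MODES:
        raise ValueError(f"unknown student mode: {student_mode}")

    arch = ARCHITECTURES[dataset]
    model_dir = f"./models/teacher_{arch}_{dataset}"
    commands = [
        ["python", "train_model_ce.py", "--mode", "teacher",
         "--dataset", dataset, "--architecture", arch],
        ["python", "get_soft_labels.py", "--dataset", dataset,
         "--sr_mode", sr_mode, "--model_dir", model_dir],
        ["python", "train_model_kd.py", "--dataset", dataset,
         "--mode", student_mode, "--logits", logits_path],
    ]
    return [shlex.join(command) for command in commands]


# PATH_OF_SAVED_LOGITS is kept as a placeholder, exactly as in this README.
cmds = pipeline_commands("CIFAR10", "mbd", "small", "PATH_OF_SAVED_LOGITS")
for cmd in cmds:
    print(cmd)
```

Validating the dataset, robustness mode, and student mode up front catches typos before any long-running training job is launched.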
If you find this code useful, please consider citing the following work. Thank you!
@inproceedings{wang2021zero,
title={Zero-shot knowledge distillation from a decision-based black-box model},
author={Wang, Zi},
booktitle={International Conference on Machine Learning},
pages={10675--10685},
year={2021},
organization={PMLR}
}