cGAN-KD

A unified cGAN-based knowledge distillation method


Distilling and Transferring Knowledge via cGAN-generated Samples for Image Classification and Regression

This repository provides the source code for the experiments in our paper.
If you use this code, please cite:


@misc{ding2021distilling,
      title={Distilling and Transferring Knowledge via cGAN-generated Samples for Image Classification and Regression}, 
      author={Xin Ding and Z. Jane Wang and Zuheng Xu and Yongwei Wang and William J. Welch},
      year={2021},
      eprint={2104.03164},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

The workflow of cGAN-KD.

Evolution of fake samples' distributions and datasets.


Requirements

python=3.8.5
argparse>=1.1
h5py>=2.10.0
matplotlib>=3.2.1
numpy>=1.18.5
Pillow>=7.0.0
torch>=1.5.0
torchvision>=0.6.0
tqdm>=4.46.1
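As a quick sanity check before running anything, the pinned versions above can be compared against the active environment. This is a minimal standard-library sketch (not part of the repository); the comparison is purely numeric and ignores pre-release suffixes.

```python
# Minimal sketch: check installed packages against the minimums listed
# above. Numeric comparison only; pre-release tags are ignored.
from importlib.metadata import PackageNotFoundError, version

MINIMUMS = {"h5py": "2.10.0", "numpy": "1.18.5", "torch": "1.5.0"}

def meets_minimum(installed: str, required: str) -> bool:
    """Compare dotted version strings component-wise."""
    parse = lambda v: tuple(int(p) for p in v.split(".") if p.isdigit())
    return parse(installed) >= parse(required)

def check_environment(minimums=MINIMUMS):
    """Return a list of (package, problem) pairs; empty means all good."""
    problems = []
    for pkg, req in minimums.items():
        try:
            if not meets_minimum(version(pkg), req):
                problems.append((pkg, f"needs >={req}, found {version(pkg)}"))
        except PackageNotFoundError:
            problems.append((pkg, "not installed"))
    return problems
```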


Datasets (h5 files) and necessary checkpoints

Download and unzip cGAN-KD_data_and_ckpts.7z:
https://1drv.ms/u/s!Arj2pETbYnWQsswK03PbX-w5YQg2AQ?e=eI59gT

Then, copy the files and folders from cGAN-KD_data_and_ckpts into this repository as follows.

Put ./C-X0K/CIFAR10_trainset_X0000_seed_2020.h5 at ./CIFAR_X0K/cGAN-based_KD/data/.
Put ./C-X0K/C10_2020.hdf5 at ./CIFAR_X0K/BigGAN/data/.
Put ./C-X0K/UDA_pretrained_teachers/*.pth at ./CIFAR_X0K/Distiller/pretrained/.
Put ./C-X0K/ckpt_BigGAN_cifar10_ntrain_X0000_seed_2020 at ./CIFAR_X0K/cGAN-based_KD/Output_CIFAR10/saved_models/.
Here X is 5, 2, or 1, corresponding to C-50K, C-20K, and C-10K respectively.

Put ./Tiny-ImageNet/tiny-imagenet-200.h5 at ./Tiny-ImageNet/cGAN-based_KD/data/.
Put ./Tiny-ImageNet/UDA_pretrained_teachers/*.pth at ./Tiny-ImageNet/Distiller/pretrained/.
Put ./Tiny-ImageNet/BigGAN_weights at ./Tiny-ImageNet/cGAN-based_KD/output/saved_models/.

Put ./RC-49/dataset at ./RC-49.
Put ./RC-49/output at ./RC-49/CcGAN-based_KD.
The output/saved_models folder includes the pretrained CcGAN (SNGAN architecture), the label embedding networks, a sparse autoencoder for feature extraction, an MLP-5 for conditional density ratio estimation, and the VGG16 teacher model.
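The checkpoints above imply the filtering pipeline for regression data: encode a fake image with the sparse autoencoder, score it with the density-ratio MLP, and keep it only if the ratio clears a threshold. A schematic, framework-free sketch; the function names and model stubs are assumptions standing in for the trained networks, not the repository's actual code.

```python
# Schematic sketch of the filtering pipeline implied above: encode a
# fake image to features, estimate its conditional density ratio, and
# keep it only if the ratio clears a threshold. `encode` and
# `density_ratio` are placeholders for the trained sparse autoencoder
# and MLP-5.
def keep_sample(image, label, encode, density_ratio, threshold):
    """encode: image -> features; density_ratio: (features, label) -> float."""
    features = encode(image)
    return density_ratio(features, label) >= threshold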

Put ./UTKFace/dataset at ./UTKFace.
Put ./UTKFace/output at ./UTKFace/CcGAN-based_KD/.
The output/saved_models folder includes the pretrained CcGAN (SAGAN architecture), the label embedding networks, a sparse autoencoder for feature extraction, an MLP-5 for conditional density ratio estimation, and the VGG16 teacher model.


Sample Usage

CIFAR-10

The code for C-50K, C-20K, and C-10K is in ./CIFAR/CIFAR_50K, ./CIFAR/CIFAR_20K, and ./CIFAR/CIFAR_10K respectively. We take C-50K as the running example below.

BigGAN training

The implementation of BigGAN is mainly based on [3].
Run ./CIFAR/CIFAR_50K/BigGAN/scripts/launch_cifar10_ema.sh for C-50K (50,000 samples).

Checkpoints of BigGAN used in our experiments are in cGAN-KD_data_and_ckpts.7z.

Fake data generation

Run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_gen_synt_data.sh for C-50K (50,000 samples).

filtering_threshold in run_gen_synt_data.sh controls the filtering threshold.
NFAKE_PER_CLASS in run_gen_synt_data.sh controls the number of fake images generated for each class.
Generated fake data are stored in ./CIFAR/CIFAR_50K/cGAN-based_KD/data in the h5 format.
Note: before generating fake data, train the teacher model first by running ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_nokd.sh.

Grid search for selecting the optimal rho

Run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_rho_selection.sh
FAKE_DATASET_NAME specifies the fake dataset generated under a given rho; its name is in the format 'BigGAN_..._nfake_xxx'.
Modify FAKE_DATASET_NAME to test the performance of student models under different rho's.
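The selection step itself is just bookkeeping: run the script once per candidate rho and keep the rho whose student performs best. A tiny sketch; the rho values and accuracies below are placeholders, not results from the paper.

```python
# Hypothetical bookkeeping for the rho grid search: map each candidate
# rho to the student's test accuracy obtained with the corresponding
# fake dataset, then keep the best. All numbers are placeholders.
def best_rho(results):
    """results: dict mapping rho -> student test accuracy (%)."""
    return max(results, key=results.get)

candidates = {0.5: 91.2, 0.7: 92.0, 0.9: 91.5}  # placeholder accuracies
```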

NOKD, BLKD, TAKD, cGAN-KD, cGAN-KD+BLKD, cGAN-KD+TAKD

NOKD: run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_nokd.sh
BLKD: run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_blkd.sh
TAKD: run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_takd.sh
cGAN-KD: run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_nokd_fake.sh
cGAN-KD+BLKD: run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_blkd_fake.sh
cGAN-KD+TAKD: run ./CIFAR/CIFAR_50K/cGAN-based_KD/scripts/run_takd_fake.sh

SSKD and cGAN-KD+SSKD

NOKD: run ./CIFAR/CIFAR_50K/SSKD/scripts/run_nokd.sh
SSKD: run ./CIFAR/CIFAR_50K/SSKD/scripts/run_SSKD.sh
cGAN-KD+SSKD: run ./CIFAR/CIFAR_50K/SSKD/scripts/run_SSKD+fake.sh

BLKD+UDA and cGAN-KD+BLKD+UDA

NOKD: run ./CIFAR/CIFAR_50K/Distiller/scripts/run_nokd.sh
BLKD+UDA: run ./CIFAR/CIFAR_50K/Distiller/scripts/run_blkd+uda.sh
cGAN-KD+BLKD+UDA: run ./CIFAR/CIFAR_50K/Distiller/scripts/run_blkd+uda+fake.sh

Tiny-ImageNet

BigGAN training

We provide the checkpoints of BigGAN used in the experiment in cGAN-KD_data_and_ckpts.7z.

Fake data generation

Run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_gen_synt_data.sh.
filtering_threshold in run_gen_synt_data.sh controls the filtering threshold.
NFAKE_PER_CLASS in run_gen_synt_data.sh controls the number of fake images generated for each class.

Grid search for selecting the optimal rho

Run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_rho_selection.sh
FAKE_DATASET_NAME specifies the fake dataset generated under a given rho; its name is in the format 'BigGAN_..._nfake_xxx'.
Modify FAKE_DATASET_NAME to test the performance of student models under different rho's.

NOKD, BLKD, TAKD, cGAN-KD, cGAN-KD+BLKD, cGAN-KD+TAKD

NOKD: run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_nokd.sh
BLKD: run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_blkd.sh
TAKD: run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_takd.sh
cGAN-KD: run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_nokd_fake.sh
cGAN-KD+BLKD: run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_blkd_fake.sh
cGAN-KD+TAKD: run ./Tiny-ImageNet/cGAN-based_KD/scripts/run_takd_fake.sh

SSKD and cGAN-KD+SSKD

NOKD: run ./Tiny-ImageNet/SSKD/scripts/run_nokd.sh
SSKD: run ./Tiny-ImageNet/SSKD/scripts/run_SSKD.sh
cGAN-KD+SSKD: run ./Tiny-ImageNet/SSKD/scripts/run_SSKD+fake.sh

BLKD+UDA and cGAN-KD+BLKD+UDA

NOKD: run ./Tiny-ImageNet/Distiller/scripts/run_nokd.sh
BLKD+UDA: run ./Tiny-ImageNet/Distiller/scripts/run_blkd+uda.sh
cGAN-KD+BLKD+UDA: run ./Tiny-ImageNet/Distiller/scripts/run_blkd+uda+fake.sh

RC-49

We take R-25 as the example to show how to run the experiment.

CcGAN training and fake data generation

The implementation of CcGAN is mainly based on [1] and [2], but we change the loss function to the hinge loss.

Run ./RC-49/CcGAN-based_KD/scripts/run_gen_r25.sh for R-25 (25 images per angle).

filtering_threshold controls the filtering threshold.

Grid search for selecting the optimal rho

Run ./RC-49/CcGAN-based_KD/scripts/run_rho_selection_r25.sh

FAKE_DATASET_NAME specifies the fake dataset's name, which is in the format 'GANNAME_..._nfake_xxx'.
Modify FAKE_DATASET_NAME to test the performance of student models under different rho's.

NOKD and cGAN-KD

NOKD: run ./RC-49/CcGAN-based_KD/scripts/run_nokd_r25.sh
cGAN-KD: run ./RC-49/CcGAN-based_KD/scripts/run_nokd+fake_r25.sh

UTKFace

CcGAN training and fake data generation

The implementation of CcGAN is mainly based on [1] and [2] but we use the SAGAN architecture and hinge loss.

Run ./UTKFace/CcGAN-based_KD/scripts/run_gen.sh.
filtering_threshold controls the filtering threshold.

Grid search for selecting the optimal rho

Run ./UTKFace/CcGAN-based_KD/scripts/run_rho_selection.sh
FAKE_DATASET_NAME specifies the fake dataset's name, which is in the format 'GANNAME_..._nfake_xxx'.
Modify FAKE_DATASET_NAME to test the performance of student models under different rho's.

NOKD and cGAN-KD

NOKD: run ./UTKFace/CcGAN-based_KD/scripts/run_nokd.sh
cGAN-KD: run ./UTKFace/CcGAN-based_KD/scripts/run_nokd_fake.sh


Some Results

  • CIFAR-10

  • RC-49


References

[1] X. Ding, Y. Wang, Z. Xu, W. J. Welch, and Z. J. Wang, “CcGAN: Continuous conditional generative adversarial networks for image generation,” in International Conference on Learning Representations, 2021.
[2] X. Ding, Y. Wang, Z. Xu, W. J. Welch, and Z. J. Wang, “Continuous conditional generative adversarial networks for image generation: Novel losses and label input mechanisms,” arXiv preprint arXiv:2011.07466, 2020. https://github.com/UBCDingXin/improved_CcGAN
[3] https://github.com/ajbrock/BigGAN-PyTorch
[4] X. Ding et al., “Efficient Subsampling for Generating High-Quality Images from Conditional Generative Adversarial Networks,” arXiv preprint arXiv:2103.11166, 2021. https://github.com/UBCDingXin/cDRE-based_Subsampling_cGANS