This repo contains the original PyTorch implementation of the following paper:
On the Surprising Efficacy of Distillation as an Alternative to Pre-Training Small Models
Sean Farhat, Deming Chen
University of Illinois, Urbana-Champaign
which will appear at the 5th Practical ML for Low Resource Settings (PML4LRS) Workshop at ICLR 2024.
Go into `paths.json` and edit the paths to point to where you'd like the datasets (`data_path`, `syn_data_path`) and checkpoints (`cpt_path`) to be saved and loaded.
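For reference, a minimal `paths.json` might look like the sketch below. The key names are the ones listed above; the directory values are placeholders, and the file that ships with the repo may contain additional keys.

```bash
# Write a minimal paths.json (or edit the existing file to the same effect).
# The directories are placeholders; point them wherever you want datasets,
# synthetic data, and checkpoints to live.
cat > paths.json <<'EOF'
{
  "data_path": "/your/datasets",
  "syn_data_path": "/your/synthetic_data",
  "cpt_path": "/your/checkpoints"
}
EOF
```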
To obtain the desired teachers, we must first train them on the task of interest. This can be achieved via 3 methods, depending on how we wish for the teacher to learn.
- Scratch (FR): The network is randomly initialized and fully trained on the task end-to-end.
- Linear Probed (LP): The network is initialized with a feature backbone pre-trained on ImageNet. Then, only its task-specific head is trained on the task.
- Full Finetuning (FT): The network is initialized with a feature backbone pre-trained on ImageNet. Then, it is fully trained on the task end-to-end. Its task-specific head uses a higher learning rate than the body.
Note: For our experiments, we take the ImageNet pre-trained weights from PyTorch's model hub.
To create these teachers, edit and run the appropriate `scripts/train_(fr|lp|ft).sh` script. All possible command line arguments can be found by running `python train_(fr|lp|ft).py --help`. These scripts will create and save the best and last model checkpoints in `<data_path>/<model>_<fr|lp|ft>_<optimizer>`.
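As a rough sketch of the workflow for, say, a fully finetuned (FT) teacher; the exact hyperparameter flags are not listed here, so check `--help` before running:

```bash
# See what arguments the full-finetuning trainer accepts.
python train_ft.py --help

# Edit scripts/train_ft.sh (dataset, model, optimizer, etc.), then run it.
bash scripts/train_ft.sh
```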
Edit and run `scripts/distill.sh`. In this script, we have several options to control the assistance process.
- `--dataset` chooses the task. Options: `cifar100`, `cifar10`, `mit_indoor`, `cub_2011`, `caltech101`, `dtd`
- `--teacher_model` chooses which teacher model we use. Step 1 must be completed for this script to find the desired model. Options: `resnet50`, `vit-b-16`
- `--teacher_init` chooses the initialization of the teacher. Options: `fr`, `lp`, `ft`
- `--student_model` chooses the student model, which is initialized randomly. Options: `mobilenetv2`, `resnet18`
- `--distill` chooses the distillation algorithm. Options: `align_uniform`, `crd`, `kd`, `srrl`
All possible command line arguments can be found by running `python distill.py --help`.
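For example, using only the flags documented above, a distillation run might look like the sketch below; any remaining hyperparameters are left at whatever `scripts/distill.sh` or the defaults provide:

```bash
# Distill an FT-initialized ResNet-50 teacher into a randomly initialized
# ResNet-18 student on CIFAR-100 using vanilla knowledge distillation.
python distill.py \
    --dataset cifar100 \
    --teacher_model resnet50 \
    --teacher_init ft \
    --student_model resnet18 \
    --distill kd
```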
First, we have to generate the synthetic data. To do this, edit and run `scripts/generate.sh`. This will save the synthetic data in the `syn_data_path` from `paths.json`.
Then, edit and run `scripts/gen_distill.sh`. It works similarly to `distill.sh`, with the addition of the following options:
- `--synset_size` chooses how much of the synthetic dataset we wish to use, as a fraction of the training set size. Options: `1x`, `2x` (Note: enough synthetic images must be generated for these to work correctly.)
- `--aug` enables image augmentations
- `--aug_mode` chooses whether to apply a single augmentation or multiple augmentations. Options: `S`, `M`

All possible command line arguments can be found by running `python generated_distill.py --help`.
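Putting both steps together, a sketch of the synthetic-data pipeline might look like this; the shared flags are the ones documented for `distill.sh`, and anything beyond them should be verified with `--help`:

```bash
# Step 1: generate synthetic images (edit scripts/generate.sh first so it
# produces enough images for the --synset_size you plan to use).
bash scripts/generate.sh

# Step 2: distill with the synthetic data, e.g. a 1x-sized synthetic set
# with a single augmentation applied.
python generated_distill.py \
    --dataset cifar100 \
    --teacher_model resnet50 \
    --teacher_init ft \
    --student_model resnet18 \
    --distill kd \
    --synset_size 1x \
    --aug \
    --aug_mode S
```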
Weights and Biases (wandb) integration is included for all scripts above. Assuming you have `wandb` set up on your machine, simply add the `--logging` flag to each script.
We have included a convenient `--timing` flag for all scripts that will run the task for one epoch and report how long it took.
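For example, appended to the distillation entry point from earlier:

```bash
# Log the run to Weights & Biases (requires wandb to be set up beforehand).
python distill.py --dataset cifar100 --teacher_model resnet50 --teacher_init ft \
    --student_model resnet18 --distill kd --logging

# Or run a single timed epoch instead of a full training run.
python distill.py --dataset cifar100 --teacher_model resnet50 --teacher_init ft \
    --student_model resnet18 --distill kd --timing
```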
@misc{farhat2024surprising,
      title={On the Surprising Efficacy of Distillation as an Alternative to Pre-Training Small Models},
      author={Sean Farhat and Deming Chen},
      year={2024},
      eprint={2404.03263},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}