# Knowledge-Distillation-Zoo

PyTorch implementation of various knowledge distillation methods.

This repository is a simple reference; many tricks, such as step-by-step training, iterative training, and ensembles of teachers, are not considered.
Note that there are some differences between this repository and the original papers:

- For `fitnet`: the training procedure is one stage, without the hint layer.
- For `at`: I use the sum of absolute values raised to the power p=2 as the attention map (see the sketch after this list).
- For `nst`: I use squared MMD matching.
- For `dml`: only two nets are employed. The networks are the same as in Table 6 of the paper.
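For illustration, here is a minimal sketch of what the `at` choice above amounts to. The function names are mine, not the repo's, and details may differ from the actual `train_at.py`:

```python
import torch.nn.functional as F

def attention_map(fm, p=2, eps=1e-6):
    # Channel-wise sum of absolute values raised to power p, then L2-normalized.
    am = fm.abs().pow(p).sum(dim=1).view(fm.size(0), -1)  # (N, H*W)
    return am / (am.norm(p=2, dim=1, keepdim=True) + eps)

def at_loss(fm_s, fm_t, p=2):
    # MSE between the normalized attention maps of student and teacher features.
    return F.mse_loss(attention_map(fm_s, p), attention_map(fm_t, p))
```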
Create a `./dataset` directory and download CIFAR10/CIFAR100 into it.
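If you prefer to fetch the data programmatically, a minimal sketch using torchvision (this assumes the training scripts read the standard torchvision layout under `./dataset`):

```python
import torchvision

# Download CIFAR10 and CIFAR100 into ./dataset (skipped if already present).
torchvision.datasets.CIFAR10(root='./dataset', train=True, download=True)
torchvision.datasets.CIFAR100(root='./dataset', train=True, download=True)
```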
Use the training scripts by passing the parameters listed in `train_xxx.py` as flags, or by changing them manually. The parameters I used can be found in the training logs.
For baseline:

```
python train_baseline.py \
    --data_name=cifar10/cifar100 \
    --net_name=resnet20/resnet110 \
    --num_class=10/100
```
For `logits`, `st`, `fitnet`, `at`, `fsp`, `nst`, `pkt`, `ft`:

```
python train_xxx.py \
    --s_init=/path/to/your/student_initial_model \
    --t_model=/path/to/your/teacher_model \
    --data_name=cifar10/cifar100 \
    --t_name=resnet20/resnet110 \
    --s_name=resnet20/resnet110 \
    --num_class=10/100
```
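As a concrete example of the kind of objective these scripts optimize, a minimal sketch of the `st` (soft-target) loss. The temperature value is my assumption, not necessarily what the training logs use:

```python
import torch.nn.functional as F

def st_loss(logits_s, logits_t, T=4.0):
    # KL divergence between temperature-softened student and teacher
    # distributions, scaled by T^2 to keep gradient magnitudes comparable.
    log_p_s = F.log_softmax(logits_s / T, dim=1)
    p_t = F.softmax(logits_t / T, dim=1)
    return F.kl_div(log_p_s, p_t, reduction='batchmean') * (T * T)
```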
For `dml`:

```
python train_dml.py \
    --net1_init=/path/to/your/net1_initial_model \
    --net2_init=/path/to/your/net2_initial_model \
    --data_name=cifar10/cifar100 \
    --net1_name=resnet20/resnet110 \
    --net2_name=resnet20/resnet110 \
    --num_class=10/100
```
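In `dml`, the two peers train each other rather than learning from a fixed teacher. A minimal sketch of the mutual term (function names and the `detach` choice are mine):

```python
import torch.nn.functional as F

def dml_kl(logits_a, logits_b, T=1.0):
    # KL term pushing net a toward net b's (detached) predictions.
    log_p_a = F.log_softmax(logits_a / T, dim=1)
    p_b = F.softmax(logits_b / T, dim=1).detach()
    return F.kl_div(log_p_a, p_b, reduction='batchmean') * (T * T)

# Each peer minimizes cross-entropy plus the mutual term:
#   loss1 = F.cross_entropy(logits1, targets) + dml_kl(logits1, logits2)
#   loss2 = F.cross_entropy(logits2, targets) + dml_kl(logits2, logits1)
```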
The trained baseline models are used as teachers. For a fair comparison, all student nets share the same initialization as the baseline models. The initial models, trained models, and training logs are uploaded here.

The loss trade-off parameters `--lambda_xxx` were not chosen carefully, so the following results do not indicate which method is better than the others.
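For context, these flags weight the distillation term against the cross-entropy loss; a minimal sketch of that combination (the exact per-method form lives in each `train_xxx.py`):

```python
import torch.nn.functional as F

def total_loss(logits_s, targets, kd_term, lambda_kd):
    # kd_term is the method-specific distillation loss (st, at, nst, ...);
    # lambda_kd corresponds to the --lambda_xxx flag.
    return F.cross_entropy(logits_s, targets) + lambda_kd * kd_term
```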
| Teacher | Student | Method | CIFAR10 | CIFAR100 |
|:---:|:---:|:---:|:---:|:---:|
| - | resnet-20 | baseline | 92.18% | 68.33% |
| resnet-20 | resnet-20 | logits | 93.01% | 69.87% |
| resnet-20 | resnet-20 | st | 92.54% | 69.92% |
| resnet-20 | resnet-20 | fitnet | 92.48% | 69.05% |
| resnet-20 | resnet-20 | at | 92.58% | 68.56% |
| resnet-20 | resnet-20 | fsp | 92.57% | 69.10% |
| resnet-20 | resnet-20 | nst | 92.35% | 68.35% |
| resnet-20 | resnet-20 | pkt | 92.83% | 68.83% |
| resnet-20 | resnet-20 | ft | 92.92% | 68.86% |
| Teacher | Student | Method | CIFAR10 | CIFAR100 |
|:---:|:---:|:---:|:---:|:---:|
| - | resnet-20 | baseline | 92.18% | 68.33% |
| - | resnet-110 | baseline | 94.04% | 72.65% |
| resnet-110 | resnet-20 | logits | 93.33% | 69.94% |
| resnet-110 | resnet-20 | st | 92.82% | 69.45% |
| resnet-110 | resnet-20 | fitnet | 92.55% | 69.68% |
| resnet-110 | resnet-20 | at | 92.84% | 69.05% |
| resnet-110 | resnet-20 | fsp | 92.83% | 69.38% |
| resnet-110 | resnet-20 | nst | 92.51% | 68.41% |
| resnet-110 | resnet-20 | pkt | 92.95% | 69.04% |
| resnet-110 | resnet-20 | ft | 93.20% | 69.45% |
| Teacher | Student | Method | CIFAR10 | CIFAR100 |
|:---:|:---:|:---:|:---:|:---:|
| - | resnet-110 | baseline | 94.04% | 72.65% |
| resnet-110 | resnet-110 | logits | 94.48% | 74.72% |
| resnet-110 | resnet-110 | st | 94.30% | 74.29% |
| resnet-110 | resnet-110 | fitnet | 94.58% | 73.21% |
| resnet-110 | resnet-110 | at | 94.34% | 73.81% |
| resnet-110 | resnet-110 | fsp | 94.29% | 73.71% |
| resnet-110 | resnet-110 | nst | 94.27% | 72.84% |
| resnet-110 | resnet-110 | pkt | 94.76% | 73.73% |
| resnet-110 | resnet-110 | ft | 94.46% | 73.41% |
| Net1 | Net2 | Method | CIFAR10 | CIFAR100 |
|:---:|:---:|:---:|:---:|:---:|
| - | resnet-20 | baseline | 92.18% | 68.33% |
| - | resnet-110 | baseline | 94.04% | 72.65% |
| resnet-20 | resnet-20 | dml | 92.99%/92.81% | 70.30%/70.19% |
| resnet-110 | resnet-20 | dml | 94.52%/92.72% | 75.25%/70.26% |
| resnet-110 | resnet-110 | dml | 94.92%/94.46% | 74.70%/74.91% |
Requirements:

- python 2.7
- pytorch 1.0.0
- torchvision 0.2.1