This is the code for the CVPR 2020 oral paper "Distilling Cross-Task Knowledge via Relationship Matching". If you use any content of this repo in your work, please cite the following bib entry:
@inproceedings{ye2020refilled,
author = {Han-Jia Ye and
Su Lu and
De-Chuan Zhan},
title = {Distilling Cross-Task Knowledge via Relationship Matching},
booktitle = {Computer Vision and Pattern Recognition (CVPR)},
year = {2020}
}
It is intuitive to take advantage of the learning experience from related pre-trained models to facilitate model training in the current task. Different from fine-tuning or parameter regularization, knowledge distillation (knowledge reuse) extracts various kinds of dark knowledge (privileged information) from a fixed strong model (the "teacher") and enriches the training of the target model (the "student") with additional signals. Owing to the strong correspondence between classifier and class, however, it is difficult to reuse the classification knowledge of a cross-task teacher model.
We propose the RElationship FacIlitated Local cLassifiEr Distillation (REFILLED), which decomposes the knowledge distillation flow into two parts, one for the embedding and one for the top-layer classifier. REFILLED contains two stages. In the first stage, the discriminative ability of the features is emphasized. For hard triplets determined by the embedding of the student model, the teacher's comparison between the instances is used as soft supervision: for each object, the teacher specifies the proportion by which a dissimilar impostor should be farther away than a target nearest neighbor, which enhances the discriminative embedding of the student. In the second stage, the teacher constructs soft supervision for each instance by measuring its similarity to a local center. By matching the "instance-label" predictions across models, the cross-task teacher improves the learning efficacy of the student.
We further improve the proposed method by extending the dimension of the matched tuple probabilities in stage 1 and by replacing local class centers with global class centers in stage 2.
REFILLED can be used in several applications, e.g., standard knowledge distillation, cross-task knowledge distillation, and middle-shot learning. Standard knowledge distillation is the most widely used setting, and we show results under this setting below. Experimental results for cross-task knowledge distillation and middle-shot learning can be found in the paper.
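The stage 1 idea can be illustrated with a small sketch. It assumes a stochastic-triplet-embedding style probability over squared Euclidean distances and a KL divergence between the teacher's and the student's triplet probabilities; the function names, the distance choice, and the toy embeddings are illustrative assumptions, not the repository's actual code.

```python
import math


def sq_dist(a, b):
    """Squared Euclidean distance between two embedding vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def triplet_prob(anchor, neighbor, impostor, tau):
    """Probability that `anchor` is closer to `neighbor` than to `impostor`.

    Equals exp(-d_pos/tau) / (exp(-d_pos/tau) + exp(-d_neg/tau)),
    rewritten as a logistic function of the distance gap.
    """
    d_pos = sq_dist(anchor, neighbor)
    d_neg = sq_dist(anchor, impostor)
    return sigmoid((d_neg - d_pos) / tau)


def bernoulli_kl(p, q, eps=1e-12):
    """KL(p || q) between two Bernoulli distributions, clipped for stability."""
    p = min(max(p, eps), 1.0 - eps)
    q = min(max(q, eps), 1.0 - eps)
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))


# Toy triplet (anchor, target neighbor, impostor) as embedded by each model.
# These vectors are made up purely for illustration.
teacher_emb = ([0.0, 0.0], [0.1, 0.0], [2.0, 0.0])
student_emb = ([0.0, 0.0], [0.5, 0.0], [1.0, 0.0])

p_teacher = triplet_prob(*teacher_emb, tau=1.0)  # teacher's soft comparison
p_student = triplet_prob(*student_emb, tau=1.0)  # student's current comparison
loss = bernoulli_kl(p_teacher, p_student)        # quantity the student would minimize
```

In this sketch the teacher separates the impostor from the neighbor more sharply than the student does, so the loss is positive and shrinks as the student's relative distances approach the teacher's.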
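A rough sketch of how soft labels from class centers might be built and matched in stage 2 follows. It assumes that class centers are per-class means of teacher embeddings (computed over a mini-batch for "local" centers or over the whole training set for "global" ones), that similarity is the negative squared distance, and that the student loss is cross-entropy plus a weighted KL term. All of these choices and all function names are assumptions for illustration, not the repository's implementation.

```python
import math


def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def class_centers(embeddings, labels):
    """Mean embedding per class; pass a batch for local or a full set for global centers."""
    sums, counts = {}, {}
    for emb, y in zip(embeddings, labels):
        if y not in sums:
            sums[y] = [0.0] * len(emb)
            counts[y] = 0
        sums[y] = [s + e for s, e in zip(sums[y], emb)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def teacher_soft_label(embedding, centers, class_order, tau):
    """Soft label from similarity (negative squared distance) to each class center."""
    return softmax([-sq_dist(embedding, centers[c]) / tau for c in class_order])


def stage2_loss(student_logits, label_index, teacher_probs, tau, lambd):
    """Cross-entropy on the true label plus lambd * KL(teacher || student)."""
    ce = -math.log(softmax(student_logits)[label_index])
    student_soft = softmax([z / tau for z in student_logits])
    kl = sum(t * math.log(t / max(s, 1e-12))
             for t, s in zip(teacher_probs, student_soft) if t > 0)
    return ce + lambd * kl


# Toy teacher embeddings for two classes (made-up values).
teacher_embs = [[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]]
labels = [0, 0, 1, 1]
centers = class_centers(teacher_embs, labels)
soft = teacher_soft_label([1.0, 1.0], centers, class_order=[0, 1], tau=1.0)
loss = stage2_loss([2.0, 0.5], label_index=0, teacher_probs=soft, tau=1.0, lambd=1.0)
```

Because the soft label depends only on distances to centers in the teacher's embedding space, it can be computed for classes the teacher's own classifier was never trained on, which is the property a cross-task setting needs.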
(depth, width) | (40,2) | (16,2) | (40,1) | (16,1) |
---|---|---|---|---|
Teacher | 76.04 | | | |
Student | 76.04 | 70.15 | 71.53 | 66.30 |
Paper Results | 77.49 | 74.01 | 72.72 | 67.56 |
REFILLED after stage1 (paper) | 55.47 | 50.14 | 45.04 | 38.06 |
REFILLED after stage1 (new) | 62.12 | 53.86 | 52.71 | 44.33 |
Results after stage 1 are accuracies of a nearest class mean (NCM) classifier, rather than the NMI of clustering.
width multiplier | 1.00 | 0.75 | 0.50 | 0.25 |
---|---|---|---|---|
Teacher | 76.19 | | | |
Student | 76.19 | 74.49 | 72.68 | 68.80 |
Paper Results | 78.95 | 78.01 | 76.11 | 73.42 |
REFILLED after stage1 (paper) | 36.56 | 33.00 | 29.60 | 19.10 |
REFILLED after stage1 (new) | 38.47 | 36.95 | 33.71 | 25.34 |
Results after stage 1 are accuracies of a nearest class mean (NCM) classifier, rather than the NMI of clustering.
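For readers unfamiliar with the NCM evaluation used for the stage 1 rows, here is a minimal sketch of a nearest class mean classifier on embeddings. It assumes squared Euclidean distance and toy data; the function names are hypothetical, and the repository's evaluation code may differ.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def fit_class_means(embeddings, labels):
    """Compute the mean embedding of each class from labeled training embeddings."""
    sums, counts = {}, {}
    for emb, y in zip(embeddings, labels):
        if y not in sums:
            sums[y] = [0.0] * len(emb)
            counts[y] = 0
        sums[y] = [s + e for s, e in zip(sums[y], emb)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def ncm_predict(embedding, means):
    """Assign the label whose class mean is nearest to the embedding."""
    return min(means, key=lambda y: sq_dist(embedding, means[y]))


def ncm_accuracy(embeddings, labels, means):
    """Fraction of embeddings whose nearest class mean matches the true label."""
    correct = sum(ncm_predict(e, means) == y for e, y in zip(embeddings, labels))
    return correct / len(labels)


# Toy embeddings standing in for the student's features (made-up values).
train_embs = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
train_labels = ["a", "a", "b", "b"]
means = fit_class_means(train_embs, train_labels)

test_embs = [[0.5, 0.5], [4.0, 5.0]]
test_labels = ["a", "b"]
acc = ncm_accuracy(test_embs, test_labels, means)
```

Since stage 1 trains only the embedding, a classifier of this kind needs no learned top layer, which is why it is a natural way to score the embedding before stage 2.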
This code implements REFILLED under the setting where a source task and a target task are given. main.py is the main file, and the arguments it takes are listed below.
- `data_name`: name of dataset
- `teacher_network_name`: architecture of teacher model
- `student_network_name`: architecture of student model

- `devices`: list of gpu ids
- `flag_gpu`: whether to use gpu or not
- `flag_no_bar`: whether to use a bar
- `n_workers`: number of workers in data loader
- `flag_tuning`: whether to tune the hyperparameters on validation set or train on the whole training set

- `lr1`: initial learning rate in stage 1
- `lr2`: initial learning rate in stage 2
- `point`: when to decrease the learning rate
- `gamma`: the extent of learning rate decrease
- `wd`: weight decay
- `mo`: momentum

- `depth`: depth of resnet and wide_resnet
- `width`: width of wide_resnet
- `ca`: channel coefficient of mobile_net
- `dropout_rate`: dropout rate of the network

- `n_training_epochs1`: number of training epochs in stage 1
- `n_training_epochs2`: number of training epochs in stage 2
- `batch_size`: batch size in training
- `tau1`: temperature for stochastic triplet embedding in stage 1
- `tau2`: temperature for local distillation in stage 2
- `lambd`: weight of teaching loss in stage 2
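To make the interplay of the learning-rate arguments concrete, the sketch below shows one plausible reading: `point` is taken to be a list of epoch milestones and `gamma` a multiplicative decay factor applied at each milestone. This interpretation, the helper name `step_decay_lr`, and the example values are assumptions; check main.py for the actual schedule.

```python
def step_decay_lr(initial_lr, epoch, points, gamma):
    """Learning rate at `epoch`, multiplied by `gamma` once for every milestone reached.

    `points` is assumed to be a list of epoch indices at which the rate decays.
    """
    n_decays = sum(1 for p in points if epoch >= p)
    return initial_lr * (gamma ** n_decays)


# Hypothetical values, chosen only to illustrate the schedule.
schedule = [step_decay_lr(0.1, epoch, points=[100, 150], gamma=0.1)
            for epoch in (0, 99, 100, 149, 150)]
```

Under this reading the rate stays at its initial value until the first milestone and then drops by a factor of `gamma` at each subsequent one, with `lr1` and `lr2` supplying the initial value for stage 1 and stage 2 respectively.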