
Improve a model's accuracy by distilling knowledge to its earlier layers. Improves the accuracy and performance of lightweight DNN models.


Self Distillation


This repository contains the code for the conference paper **Self Distillation: Be your own Teacher**.

When building lightweight deep learning models, the trade-off between accuracy and efficiency is the paramount concern. This repository demonstrates how, by combining early exiting with knowledge distillation, it is possible to improve both efficiency and accuracy at the same time.
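Early exiting, in a nutshell: each branch exit produces a prediction, and if an early prediction is already confident enough, the later layers are skipped. Below is a minimal sketch of entropy-based exit routing; the `threshold` value and function names are illustrative assumptions, not the repository's exact routing logic.

```python
import numpy as np

def entropy(probs: np.ndarray) -> float:
    """Shannon entropy of a softmax output; lower means more confident."""
    probs = np.clip(probs, 1e-12, 1.0)
    return float(-np.sum(probs * np.log(probs)))

def early_exit_predict(exit_probs, threshold=0.5):
    """Return (exit index, predicted class) from the first confident exit.

    exit_probs: softmax vectors ordered from earliest branch to main exit.
    threshold: illustrative confidence cutoff; tune per model and dataset.
    """
    for i, probs in enumerate(exit_probs[:-1]):
        if entropy(probs) < threshold:
            return i, int(np.argmax(probs))  # confident: skip later layers
    return len(exit_probs) - 1, int(np.argmax(exit_probs[-1]))  # main exit
```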


Knowledge Distillation vs. Self Distillation


In Self Distillation, the main exit of a model acts as a teacher for the branch exits that are added earlier in the model's structure and trained. By using the knowledge of the main exit as a teaching signal, the branches' accuracy improves, and the branch predictions can be used as prediction outputs, reducing the overall processing time needed per input.
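To make the teacher/student relationship concrete, here is a minimal sketch of a per-branch self-distillation loss in TensorFlow. The `temperature` and `alpha` values are illustrative assumptions and may differ from the losses used in the notebook.

```python
import tensorflow as tf

def self_distillation_loss(y_true, branch_logits, main_logits,
                           temperature=3.0, alpha=0.3):
    """Per-branch loss: hard-label cross-entropy plus a soft term that
    pulls the branch toward the main exit's (teacher's) predictions."""
    # Supervised term on the ground-truth labels.
    hard = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, branch_logits, from_logits=True)
    # Soft targets from the main exit; stop_gradient keeps the teacher fixed.
    teacher = tf.nn.softmax(tf.stop_gradient(main_logits) / temperature)
    student = tf.nn.log_softmax(branch_logits / temperature)
    # Cross-entropy against soft targets, scaled by T^2 (Hinton et al., 2015).
    soft = -tf.reduce_sum(teacher * student, axis=-1) * temperature**2
    return (1.0 - alpha) * hard + alpha * soft
```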

Self Distillation improves accuracy across a range of DNN model structures, even already very lightweight designs such as SqueezeNet and EfficientNet.


Self distillation also improves the accuracy of the original (main) exit, meaning that this training process is beneficial even for non-branching models.


Source files are located in the `branching` folder, with the model class defined in `branching\core_v2.py`.

To run, go to `notebooks` and launch `jupyter notebook self_distillation.ipynb`. This notebook contains a working example of the branching process and the self distillation of a branched model. In outline (see the sketch after this list):

- In the "branch the model" section, replace `models/resnet_CE_entropy_finetuned.hdf5` with the model you wish to branch.
- In the `add_branch` or `add_distill` function, specify the branch structure in `branch_layers` (`_branch_Distill` and `_branch_conv1` are predefined).
- Select the branch points in `branch_points` (`"conv2_block1_out"` and `"conv2_block3_out"` should work for any ResNet model; otherwise, change these names to the layer names reported by TensorFlow's `model.summary()`).
- Provide an individual loss for each exit; in this example that means three losses, with the first loss belonging to the main exit.
- Only one optimizer per model is currently supported.
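Putting the steps together, a hypothetical end-to-end sketch is shown below. The `BranchModel` class name and the exact signatures of `add_distill`, `branch_layers`, and `branch_points` are assumptions based on the steps above; check `branching\core_v2.py` and the notebook for the real API.

```python
import tensorflow as tf
from branching.core_v2 import BranchModel  # hypothetical class name

# Load the trained base model you want to branch.
base = tf.keras.models.load_model("models/resnet_CE_entropy_finetuned.hdf5")

# Attach distillation branches at two ResNet block outputs
# (signature is an assumption; see the notebook for the real call).
branched = BranchModel.add_distill(
    base,
    branch_layers=["_branch_Distill"],                       # predefined structure
    branch_points=["conv2_block1_out", "conv2_block3_out"],  # from model.summary()
)

# One loss per exit: the first entry is the main exit, then one per branch.
branched.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),  # only one optimizer per model
    loss=["sparse_categorical_crossentropy"] * 3,
)
```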