
Primary language: Python. License: MIT.

Knowledge distillation

An implementation of the knowledge distillation technique from the paper "Distilling the Knowledge in a Neural Network" (Hinton et al., 2015), built with pytorch_lightning, wandb and hydra.

Why do we need to distill our models?

Neural networks tend to be deep, with millions of weights and activations (for example, ResNet-50 or Transformer models). These large models are compute-intensive, which means that even with dedicated acceleration hardware, the inference pass (network evaluation) takes time. You might think that latency is an issue only in certain cases, such as autonomous driving systems, but whenever we humans interact with our phones and computers, we are sensitive to the latency of the interaction. We don't like to wait for search results or for an application or web page to load, and we are especially sensitive in real-time interactions such as speech recognition. So inference latency is often something we want to minimize.

Large models are also memory-intensive, with millions of parameters. Moving around all of the data required to compute inference results consumes energy, which is a problem on mobile devices as well as in server environments.

The storage and transfer of large neural networks is also a challenge on mobile devices, because of limits on application size and long download times.

For these reasons, we want to compress the network as much as possible, to reduce the bandwidth and compute it requires.

What is knowledge distillation?

Knowledge distillation is a model compression method in which a small model is trained to mimic a pretrained, larger model (or ensemble of models). This training setting is referred to as "teacher-student", where the large model is the teacher and the small model is the student.

In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. The principal idea is that the output probabilities of a trained model carry more information than the labels, because the model assigns non-zero probabilities to incorrect classes. The relative probabilities of the incorrect answers tell us which other classes a sample has some chance of belonging to.

One problem is that, in many cases, the probabilities the teacher assigns to incorrect classes are very small (close to 0) and therefore contribute little to the loss. We can soften these probabilities by dividing the logits z by a temperature T before applying the softmax (softmax temperature):

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

where T is the temperature parameter. When T=1 we get the standard softmax function. As T grows, the probability distribution generated by the softmax function becomes softer, providing more information as to which classes the teacher found more similar to the predicted class.
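The effect of the temperature can be checked with a few lines of plain Python. This is a minimal sketch: the function name softmax_with_temperature and the teacher logits [8.0, 3.0, 1.0] are hypothetical, illustrative values, not taken from this repository.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Softmax of `logits` softened by temperature T (T=1 is the standard softmax)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for a 3-class problem (illustrative values only).
teacher_logits = [8.0, 3.0, 1.0]

for T in (1.0, 4.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>4}: " + ", ".join(f"{p:.3f}" for p in probs))
```

At T=1 almost all of the probability mass sits on the first class; as T grows, the second and third classes receive visibly more mass while the ranking of the classes stays the same.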

Hinton et al., 2015 found that it is also beneficial to train the distilled model to produce the correct labels (based on the ground truth) in addition to the teacher's soft-labels. Hence, we also calculate the "standard" loss between the student's predicted class probabilities and the ground-truth labels (also called "hard labels/targets"). We dub this loss the "student loss". When calculating the class probabilities for the student loss we use T=1.

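$$\mathcal{L}(x; W) = \alpha \, \mathcal{H}\big(y, \sigma(z_s; T=1)\big) + \beta \, \mathcal{H}\big(\sigma(z_t; T=\tau), \sigma(z_s; T=\tau)\big)$$

where x is the input, W are the student model parameters, y is the ground-truth label, H is the cross-entropy loss function, σ is the softmax function parameterized by the temperature T, τ is the temperature used for distillation, and α and β are coefficients. z_s and z_t are the logits of the student and teacher, respectively.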
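The combined loss can be sketched in PyTorch as follows, assuming logits of shape (batch, num_classes) and integer class labels. The function name distillation_loss and the defaults T=4.0, alpha=0.5, beta=0.5 are hypothetical choices for illustration, not values from this repository.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5, beta=0.5):
    """alpha * H(y, softmax(z_s)) + beta * H(softmax(z_t / T), softmax(z_s / T)).

    student_logits, teacher_logits: (batch, num_classes); labels: (batch,) class indices.
    """
    # Student loss: standard cross-entropy against the hard labels, at T = 1.
    student_loss = F.cross_entropy(student_logits, labels)

    # Distillation loss: cross-entropy between the teacher's and the student's
    # temperature-softened distributions, averaged over the batch. The teacher
    # only provides targets, so it is detached from the autograd graph.
    soft_targets = F.softmax(teacher_logits.detach() / T, dim=1)
    log_soft_student = F.log_softmax(student_logits / T, dim=1)
    distill_loss = -(soft_targets * log_soft_student).sum(dim=1).mean()

    return alpha * student_loss + beta * distill_loss


torch.manual_seed(0)
student_logits = torch.randn(8, 10, requires_grad=True)  # dummy batch, 10 classes
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()  # gradients flow only into the student logits
print(loss.item())
```

Hinton et al. note that the gradients produced by the soft targets scale as 1/T², so when hard and soft targets are combined it is common to multiply the soft term by T²; in this sketch that factor would simply be folded into beta.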

Dataset and model

  • The dataset used in this implementation is CIFAR10, which consists of 60,000 colour images of size 32x32 in 10 classes, split into 50,000 training images and 10,000 test images.
  • The teacher model is ResNet-50 and the student model is ResNet-18.

How to run

  1. First, train the teacher model:
python train.py model.name="teacher"
  2. Then train the student model while distilling knowledge from the trained teacher model:
python train.py model.name="distill"
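The second step trains the student against an already-trained, frozen teacher. One way such a step could look in plain PyTorch is sketched below; it does not reproduce this repository's pytorch_lightning module, which is not shown here. The tiny linear models stand in for ResNet-50 and ResNet-18, and distill_step, T, alpha and beta are hypothetical. The soft term uses F.kl_div instead of the cross-entropy H: the two differ only by the entropy of the teacher's distribution, which does not depend on the student parameters W, so the gradients are identical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distill_step(student, teacher, optimizer, x, y, T=4.0, alpha=0.5, beta=0.5):
    """One optimisation step of the student against a frozen, pretrained teacher."""
    teacher.eval()  # keep BatchNorm/Dropout of the trained teacher in inference mode
    with torch.no_grad():  # no gradients are computed for the teacher
        teacher_logits = teacher(x)

    student.train()
    student_logits = student(x)

    # alpha * hard-label cross-entropy (T = 1)
    # + beta * KL(softened teacher || softened student), averaged over the batch.
    hard = F.cross_entropy(student_logits, y)
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean")
    loss = alpha * hard + beta * soft

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # only the student's parameters are in the optimizer
    return loss.item()


# Hypothetical toy stand-ins for the trained teacher and the student.
torch.manual_seed(0)
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)

x = torch.randn(16, 3, 32, 32)  # dummy CIFAR10-shaped batch
y = torch.randint(0, 10, (16,))
teacher_before = [p.clone() for p in teacher.parameters()]

for step in range(3):
    print(step, distill_step(student, teacher, optimizer, x, y))
```

Because the teacher's forward pass runs under torch.no_grad() and its parameters are not passed to the optimizer, it stays exactly as it was after the first training stage.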

Results

We can monitor the distillation training in the wandb interface; the logged metrics show that the distilled student model performs comparably to the teacher model.

Reference

  1. https://intellabs.github.io/distiller/index.html#what-is-distiller