clovaai/overhaul-distillation

Does t_net also need to be trained?

TouchSkyWf opened this issue · 1 comment

d_net.train()
d_net.module.s_net.train()
d_net.module.t_net.train()

Hi,

I noticed that the teacher model is also put in training mode.
In the usual distillation process, shouldn't the teacher be in inference mode? Keeping both networks in training mode makes training extremely slow.
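For context on the snippet above: in PyTorch, Module.train() and Module.eval() apply recursively to every registered submodule. The sketch below assumes the teacher and student are stored as attributes of the distiller, as the d_net.module.s_net and d_net.module.t_net paths suggest. The Distiller class and the tiny placeholder networks are hypothetical, and the DataParallel-style wrapper implied by .module is left out.

```python
import torch.nn as nn


class Distiller(nn.Module):
    """Hypothetical wrapper: assigning t_net and s_net as attributes
    registers them as submodules of the wrapper."""

    def __init__(self, t_net: nn.Module, s_net: nn.Module):
        super().__init__()
        self.t_net = t_net
        self.s_net = s_net


# Tiny placeholder networks; the real teacher and student are much larger.
t_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
s_net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
d_net = Distiller(t_net, s_net)

d_net.eval()
print(d_net.t_net.training, d_net.s_net.training)  # False False

# train() sets the flag on the wrapper and recurses into every registered
# submodule, so this one call already puts both networks in training mode.
d_net.train()
print(d_net.t_net.training, d_net.s_net.training)  # True True
```

Under that assumption, the two extra train() calls in the snippet are redundant but harmless.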

bhheo commented

Hi @TouchSkyWf
Thank you for your interest in our research.

I set the teacher network to training mode, but that does not mean the teacher network is trained during distillation.
The teacher's features are detached in the loss computation, so backpropagation does not compute gradients for the teacher network. There is therefore no slowdown or additional computation from training the teacher.
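A minimal sketch of what detach() does here, assuming two tiny hypothetical linear layers in place of the real teacher and student: the teacher still runs its forward pass, but backward stops at the detached output, so only the student's parameters receive gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
teacher = nn.Linear(4, 3)  # hypothetical stand-in for the teacher
student = nn.Linear(4, 3)  # hypothetical stand-in for the student
x = torch.randn(8, 4)

t_feat = teacher(x)  # the teacher forward pass still runs
s_feat = student(x)
# detach() cuts the graph at the teacher output, so backward never
# reaches the teacher's parameters.
loss = (s_feat - t_feat.detach()).pow(2).mean()
loss.backward()

print(student.weight.grad is None)  # False: the student receives gradients
print(teacher.weight.grad is None)  # True: the teacher receives none
```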

loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i+1))) \
    / 2 ** (feat_num - i - 1)
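The line above weights each stage's loss by 1 / 2 ** (feat_num - i - 1), so with four feature stages the weights are 1/8, 1/4, 1/2 and 1 from the shallowest stage to the deepest. The sketch below wraps that loop in a function. The body of distillation_loss is an assumption (a partial L2 distance with a per-channel margin) and is not copied from the repository; the names total_distill_loss and margins are hypothetical, with a plain list replacing the margin%d attributes.

```python
import torch
import torch.nn.functional as F


def distillation_loss(source, target, margin):
    # Assumed partial L2: the teacher target is clamped from below at the
    # per-channel margin, and positions where the student is already at or
    # below a non-positive target contribute no loss.
    target = torch.max(target, margin)
    loss = F.mse_loss(source, target, reduction="none")
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()


def total_distill_loss(s_feats, t_feats, margins):
    feat_num = len(s_feats)
    loss_distill = 0.0
    for i in range(feat_num):
        # Shallower stages are down-weighted by powers of two; the last
        # stage gets weight 1. The teacher feature is detached as above.
        loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), margins[i]) \
            / 2 ** (feat_num - i - 1)
    return loss_distill
```

For example, with four stages of shape (1, 2, 2, 2), an all-zero student, an all-one teacher and margins of -1, each stage contributes 8 before weighting, and the weighted total is 8 * (1/8 + 1/4 + 1/2 + 1) = 15.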

I want to emphasize that training mode in PyTorch has no effect on gradient computation.
It only switches the network's modules, such as BatchNorm and Dropout, between their training and inference behavior.
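A small check of that point, using a hypothetical network with a Dropout layer: after eval(), Dropout passes inputs through unchanged, yet gradients are still computed. Gradient tracking is governed by autograd (requires_grad, detach(), torch.no_grad()) rather than by the mode flag.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical small network with a mode-dependent layer (Dropout).
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

net.eval()  # inference mode: Dropout now passes inputs through unchanged
out = net(x)
out.sum().backward()
print(out.requires_grad, net[0].weight.grad is not None)  # True True

# Only autograd controls whether a graph is recorded.
with torch.no_grad():
    out_ng = net(x)
print(out_ng.requires_grad)  # False
```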

The teacher network is kept in training mode because of its BatchNorm layers.
In inference mode, BatchNorm normalizes with running_mean and running_var, whereas in training mode it uses the statistics of the current batch, so we thought the BatchNorm mode might affect distillation performance.
In Table 8, we observe that BatchNorm in training mode performs better than in inference mode.
So we set the teacher network to training mode for the sake of its BatchNorm layers.
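To illustrate the difference, a minimal sketch with a standalone BatchNorm1d layer standing in for one of the teacher's BatchNorm layers (an assumption; the input is synthetic). One side effect worth knowing: in training mode a BatchNorm layer also updates running_mean and running_var on every forward pass, even when no gradients reach its weights, so the teacher's running statistics change during distillation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)  # running_mean starts at 0, running_var at 1
x = torch.randn(16, 3) * 5 + 2  # synthetic batch, far from those statistics

bn.train()
y_train = bn(x)  # normalized with this batch's own mean and variance
# The running mean is updated with the default momentum 0.1:
# 0.9 * 0 + 0.1 * batch mean.
print(bn.running_mean)

bn.eval()
y_eval = bn(x)  # normalized with running_mean and running_var instead
print(torch.allclose(y_train, y_eval))  # False
```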

I think other approaches, such as running the teacher forward pass inside with torch.no_grad():, might make our code more readable.
But I did not know of any method other than detach() when we wrote the paper.
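A sketch comparing the two options, assuming a hypothetical small teacher with a BatchNorm layer kept in training mode and a linear student. Both variants give the student identical gradients. The torch.no_grad() version additionally skips recording the teacher's graph, so its intermediate activations need not be kept for backward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))  # hypothetical
student = nn.Linear(4, 3)  # hypothetical
x = torch.randn(8, 4)

teacher.train()  # teacher BatchNorm stays in training mode, as in the thread

# Variant 1: detach the teacher output after a normal forward pass.
t_feat = teacher(x).detach()
loss = (student(x) - t_feat).pow(2).mean()
grad_detach = torch.autograd.grad(loss, student.weight)[0]

# Variant 2: run the teacher under torch.no_grad(); no graph is recorded.
with torch.no_grad():
    t_feat_ng = teacher(x)
loss_ng = (student(x) - t_feat_ng).pow(2).mean()
grad_no_grad = torch.autograd.grad(loss_ng, student.weight)[0]

print(torch.allclose(grad_detach, grad_no_grad))  # True
```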

Best
Byeongho Heo