Why is entire model set to train, instead of only BatchNorm modules
hector-gr opened this issue · 2 comments
The model is set to train mode in TENT so that BatchNorm modules use batch statistics at test time. However, this also sets other modules to train mode, for instance Dropout.
Lines 96 to 110 in e9e926a
It is also possible to set only the BatchNorm submodules to train mode, and I believe this would achieve the desired behaviour for BatchNorm without affecting other modules. Is my understanding correct?
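The selective approach described above can be sketched as follows (the model here is a hypothetical example, not TENT's actual network):

```python
import torch.nn as nn

# Hypothetical model containing both BatchNorm and Dropout.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.Dropout(p=0.5),
)

model.eval()  # everything in eval mode: Dropout is a no-op, BN uses running stats
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.train()  # BN now uses batch statistics; Dropout stays in eval mode

print(model[1].training, model[2].training)  # True False
```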
Notice this line: `model.requires_grad_(False)`. It disables grad for every module. Then the for-loop below re-enables grad for the BatchNorm modules only.
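That pattern can be sketched like this (a toy model; only the `requires_grad_` calls reflect the quoted code):

```python
import torch.nn as nn

# Toy model standing in for the real network.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

model.requires_grad_(False)  # disable grad for every parameter
for m in model.modules():
    if isinstance(m, nn.BatchNorm1d):
        m.requires_grad_(True)  # re-enable grad only for the BN affine params

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['1.weight', '1.bias']
```

Only the BatchNorm weight and bias remain trainable; the running statistics are buffers, not parameters, so they are unaffected by `requires_grad_`.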
As I understand it, calling `.train()` on a PyTorch Module sets submodules like BatchNorm and Dropout to train mode (i.e. BatchNorm uses batch statistics and updates its running averages, and Dropout zeroes activations with a certain probability, respectively). On the other hand, `.requires_grad_()` controls whether gradients are computed for a module's parameters.
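The two mechanisms are independent, which a quick check on a standalone BatchNorm layer illustrates (toy example, not from the repo):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.eval()                # forward uses running stats and does not update them
bn.requires_grad_(True)  # gradients still flow to weight/bias

x = torch.randn(8, 4)
out = bn(x)
out.sum().backward()

print(bn.weight.grad is not None)  # True: grad computed despite eval mode
print(bn.training)                 # False: still using running statistics
```

So train/eval mode controls which statistics are used, while `requires_grad_` independently controls which parameters receive gradients.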