DequanWang/tent

Why is the entire model set to train mode, instead of only the BatchNorm modules?

hector-gr opened this issue · 2 comments

The model is set to train mode in TENT so that BatchNorm modules use batch statistics at test time. However, this also sets other modules to train mode, for instance Dropout.

tent/tent.py

Lines 96 to 110 in e9e926a

def configure_model(model):
    """Configure model for use with tent."""
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # disable grad, to (re-)enable only what tent updates
    model.requires_grad_(False)
    # configure norm for tent updates: enable grad + force batch statistics
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
    return model

It is also possible to set only the BatchNorm submodules to train mode, and I believe this would achieve the desired behaviour for BatchNorm without affecting other modules such as Dropout. Is my understanding correct?
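For example, something along these lines would keep Dropout (and everything else) in eval mode while BatchNorm still uses batch statistics. This is only a rough sketch of what I mean; resnet18 is just a stand-in for whatever backbone is being adapted:

    import torch.nn as nn
    from torchvision.models import resnet18  # placeholder backbone for illustration

    model = resnet18()

    # keep the whole model in eval mode and freeze all parameters...
    model.eval()
    model.requires_grad_(False)

    # ...then switch only the BatchNorm layers to train mode and unfreeze them
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()
            m.requires_grad_(True)
            # force use of batch stats in train and eval modes
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None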

Notice this line: model.requires_grad_(False). It disables grad for every parameter in the model.
The for-loop below then re-enables grad for the BatchNorm modules only.
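So after configure_model, only the BN affine parameters have requires_grad=True, and only those are handed to the optimizer (which is essentially what collect_params in tent.py is for). Roughly like this, as a sketch for illustration; bn_params is a made-up helper here and resnet18 just a placeholder backbone:

    import torch
    import torch.nn as nn
    from torchvision.models import resnet18  # placeholder backbone

    def bn_params(model):
        """Collect only the affine (weight/bias) parameters of BatchNorm layers."""
        params = []
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                for p in (m.weight, m.bias):
                    if p is not None and p.requires_grad:
                        params.append(p)
        return params

    model = configure_model(resnet18())  # configure_model as quoted above
    optimizer = torch.optim.Adam(bn_params(model), lr=1e-3)  # lr is just an example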

As I understand it, calling .train() on a PyTorch Module sets submodules like BatchNorm and Dropout to train mode (BatchNorm then uses batch statistics and updates its running averages, while Dropout drops activations with a certain probability). On the other hand, .requires_grad_() only controls whether gradients are computed for a module's parameters.
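A tiny toy check of that distinction (not from the repo, just an illustrative snippet):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.BatchNorm1d(4))

    # .train()/.eval() controls layer *behavior*: Dropout drops activations and
    # BatchNorm uses batch statistics only in train mode.
    net.train()
    assert net[1].training and net[2].training
    net.eval()
    assert not net[1].training and not net[2].training

    # .requires_grad_() controls *gradient computation* for parameters,
    # independently of train/eval mode.
    net.requires_grad_(False)
    assert all(not p.requires_grad for p in net.parameters())
    net[2].requires_grad_(True)  # re-enable grads for the BatchNorm affine params only
    assert net[2].weight.requires_grad and not net[0].weight.requires_grad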