Question about the update method of ema_model
sanyouwu opened this issue · 5 comments
Hello,
I am Sanyou. Thanks for your contribution, and sorry for disturbing you again. I have two questions:
In the class WeightEMA(object):
First, what is param.data.mul_(1 - self.wd) for?
Second, I am curious why you don't use the Mean-Teacher method for updating ema_model. (I tried it, but it failed completely.) Additionally, if I don't use the mixup trick in your code framework (which I think should then be equivalent to Mean-Teacher), I can't reach Mean-Teacher's results: with 1000 labels I get 77% vs. 79% for Mean-Teacher, and with 4000 labels 87% vs. 88%.
@YU1ut
Good questions.
First, param.data.mul_(1 - self.wd) is weight decay. It is also used in the official code (https://github.com/google-research/mixmatch/blob/master/mixmatch.py#L92), but it is different from PyTorch's usual weight decay (which is implemented as an L2 penalty). So I implemented it directly in the EMA calculation.
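To see why folding weight decay into the gradient differs from shrinking the weight directly when the optimizer is Adam, here is a scalar sketch. It assumes torch.optim.Adam-style L2 weight decay (wd * p added to the gradient before the Adam update) and uses made-up toy values for lr, wd and a zero task gradient; adam_first_step is a hypothetical helper for a single Adam step from zero moment estimates, not code from this repository.

```python
import math


def adam_first_step(p, grad, lr=0.002, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one scalar Adam update starting from zero moments."""
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1)  # bias correction at step 1
    v_hat = v / (1 - beta2)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


p, task_grad, wd = 1.0, 0.0, 1e-4  # assumed toy values

# (a) L2 penalty: wd * p is added to the gradient, then Adam normalizes it.
p_l2 = adam_first_step(p, task_grad + wd * p)

# (b) Direct decay: Adam sees only the task gradient, then the weight is shrunk.
p_decoupled = adam_first_step(p, task_grad) * (1 - wd)

print(round(p_l2, 6))  # 0.998
print(round(p_decoupled, 6))  # 0.9999
```

With the L2 penalty, Adam rescales the tiny extra gradient wd * p to roughly unit size, so the weight moves by about lr however small wd is; the multiplicative form shrinks it by exactly wd * p, so its strength is set by wd alone.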
Second, Mean-Teacher only uses model.parameters() to update the EMA model, but model.parameters() does not contain the BatchNorm statistics. You can check this by printing the BatchNorm parameters of the model and of the EMA model. Mean-Teacher works well because it also feeds samples into the EMA model during training, so the EMA model gets the correct running mean and variance. In this code no samples are passed through the EMA model during training, so I copy the BatchNorm parameters from model.state_dict() at the end of each epoch to solve this problem.
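The BatchNorm point can be checked with a toy model. This sketch assumes a small hypothetical nn.Sequential network and a Mean-Teacher-style update that, as described above, loops over parameters() only (mean_teacher_update and alpha=0.99 are illustrative, not copied from either repository). After a few training-mode forward passes, the live model's running_mean has moved, while the EMA model's is still at its initial zeros.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy network; index 1 is the BatchNorm layer.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
ema_model = copy.deepcopy(model)

# BatchNorm running statistics live in state_dict() but not in parameters().
param_names = {name for name, _ in model.named_parameters()}
state_names = set(model.state_dict().keys())
print(sorted(state_names - param_names))
# ['1.num_batches_tracked', '1.running_mean', '1.running_var']


def mean_teacher_update(model, ema_model, alpha=0.99):
    """Mean-Teacher-style EMA that only touches parameters()."""
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(alpha).add_(p, alpha=1 - alpha)


model.train()
for _ in range(5):
    # No optimizer step here; a train-mode forward pass updates model's BN buffers.
    model(torch.randn(16, 4))
    mean_teacher_update(model, ema_model)

# The EMA model never saw any data, so its running_mean is still all zeros.
print(ema_model[1].running_mean.abs().sum().item())  # 0.0
print(model[1].running_mean.abs().sum().item() > 0)  # True
```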
The accuracy gap may come from other factors, such as Adam vs. SGD, learning-rate scheduling and other hyperparameters.
Thanks for your explanation. I noticed that you copy the BN parameters from self.tmp_model instead of self.model. Could you please explain why?
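# back up the current EMA weights into tmp_model
for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
    tmp_param.data.copy_(ema_param.data.detach())
# load everything from the live model, including the BN running statistics
self.ema_model.load_state_dict(self.model.state_dict())
# restore the EMA weights, so only the BN statistics come from self.model
for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
    ema_param.data.copy_(tmp_param.data.detach())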
@xmengli999
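He only wants to copy the batch norm statistics (running mean and variance, which are crucial for inference) from self.model (the current model) to ema_model at the end of every epoch. tmp_model is only a temporary backup: the EMA weights are saved into it, the whole state_dict of self.model is loaded into ema_model, and then the EMA weights are restored from tmp_model.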
This is because batch normalization statistics such as running_mean and running_var are not included in model.parameters(); however, they do appear in model.state_dict().
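The tmp_model round trip leaves ema_model with its own EMA weights plus the buffers of self.model. Assuming both models have identical architectures, a hypothetical helper that copies buffers() directly should give the same result for BatchNorm networks like this one, without needing a third model. The toy network and the 0.5 fill value are stand-ins, not repository code.

```python
import copy

import torch
import torch.nn as nn


def copy_buffers(model, ema_model):
    """Hypothetical helper: copy every buffer (BN running_mean, running_var,
    num_batches_tracked) from model into ema_model, leaving ema_model's
    parameters, i.e. the EMA weights, untouched."""
    with torch.no_grad():
        for ema_buf, buf in zip(ema_model.buffers(), model.buffers()):
            ema_buf.copy_(buf)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
ema_model = copy.deepcopy(model)
with torch.no_grad():
    for p in ema_model.parameters():
        p.fill_(0.5)  # stand-in for EMA weights that must survive the copy

model.train()
model(torch.randn(16, 4))  # updates model's BN running statistics

copy_buffers(model, ema_model)
print(torch.equal(ema_model[1].running_mean, model[1].running_mean))  # True
print(all(bool((p == 0.5).all()) for p in ema_model.parameters()))  # True
```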
Shouldn't the BN parameters of ema_model also be exponentially moving averaged? After all, the BN parameters of tmp_model are the same as those of ema_model.
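One possible way to average the BN statistics as well, sketched under the assumption that every floating-point entry of state_dict() (weights, biases, running_mean, running_var) should be averaged while integer counters such as num_batches_tracked are simply copied. ema_update_state is a hypothetical name, and this sketch says nothing about whether averaged statistics work better than copied ones.

```python
import copy

import torch
import torch.nn as nn


def ema_update_state(model, ema_model, alpha=0.999):
    """Hypothetical EMA over the whole state_dict().

    Floating-point entries (weights, biases, BN running_mean/running_var)
    are averaged; integer buffers such as num_batches_tracked are copied.
    """
    with torch.no_grad():
        ema_state = ema_model.state_dict()  # tensors share storage with ema_model
        for name, value in model.state_dict().items():
            if value.dtype.is_floating_point:
                ema_state[name].mul_(alpha).add_(value, alpha=1 - alpha)
            else:
                ema_state[name].copy_(value)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
ema_model = copy.deepcopy(model)

model.train()
model(torch.randn(16, 4))  # moves model's running_mean away from zero
ema_update_state(model, ema_model, alpha=0.5)

# ema running_mean started at zero, so after one step it is 0.5 * model's.
expected = 0.5 * model[1].running_mean
print(torch.allclose(ema_model[1].running_mean, expected))  # True
print(ema_model[1].num_batches_tracked.item())  # 1
```

This relies on state_dict() returning tensors that share storage with the module (the default keep_vars=False still returns detached views of the same memory), so the in-place mul_/add_ calls update ema_model itself.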
@YU1ut Thanks for the implementation! A few more questions on the WeightEMA class:
- In the __init__ method of the class there is:
for param, ema_param in zip(self.params, self.ema_params):
    param.data.copy_(ema_param.data)
What does it do?
I think it copies ema_param.data into param.data for each parameter, but since it operates on param.data, it doesn't seem to change the values in self.params and self.ema_params.
- WeightEMA does two things: 1) it updates ema_param with param, and 2) it applies weight decay to param on top of the Adam optimizer (since optimizer.step() is called before WeightEMA in train())? Why does the weight decay happen after ema_param has been computed from the pre-decay param values?
Thanks!
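Both questions above can be explored with a toy sketch. For the first, it assumes self.params and self.ema_params hold the live tensors of model and ema_model (here taken from parameters(); the two nn.Linear models are hypothetical). For the second, it uses plain floats with made-up alpha and wd values to compare the two orderings; ema_then_decay and decay_then_ema are illustrative helpers, not repository code.

```python
import torch
import torch.nn as nn

# --- Question 1: does param.data.copy_(ema_param.data) change anything? ---
torch.manual_seed(0)
model = nn.Linear(3, 2)
ema_model = nn.Linear(3, 2)  # independent init, so the weights differ
params = list(model.parameters())
ema_params = list(ema_model.parameters())

weights_equal_before = torch.equal(model.weight, ema_model.weight)
for param, ema_param in zip(params, ema_params):
    param.data.copy_(ema_param.data)  # in-place write into model's own tensor
weights_equal_after = torch.equal(model.weight, ema_model.weight)


# --- Question 2: EMA update before vs. after the decay (scalar toy values) ---
def ema_then_decay(p, e, alpha, wd):
    e = alpha * e + (1 - alpha) * p  # EMA sees the pre-decay weight
    return p * (1 - wd), e


def decay_then_ema(p, e, alpha, wd):
    p = p * (1 - wd)
    return p, alpha * e + (1 - alpha) * p  # EMA sees the post-decay weight


print(weights_equal_before, weights_equal_after)  # False True
print(ema_then_decay(1.0, 0.0, alpha=0.75, wd=0.5))  # (0.5, 0.25)
print(decay_then_ema(1.0, 0.0, alpha=0.75, wd=0.5))  # (0.5, 0.125)
```

Because param.data shares memory with the parameter, copy_ does change the model: in this toy setting both models end up starting from the same weights. For the ordering, the weight trajectory is the same either way (the decay always hits param), so only the EMA differs: with decay-then-EMA it equals (1 - wd) times the EMA-then-decay value, up to a term wd * alpha**t * (initial EMA) that vanishes over time.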