EMA Update of bn buffer
luyvlei opened this issue · 5 comments
The following function applies a moving average to the EMA model, but it doesn't update the batchnorm statistics (running_mean and running_var), since these two are buffers rather than parameters.
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
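The parameter/buffer split described above can be checked directly on a BatchNorm layer. This is a minimal sketch, assuming a standalone `nn.BatchNorm2d` layer with default settings rather than the model from this repo:

```python
import torch.nn as nn

# A standalone BatchNorm layer (4 channels chosen arbitrarily for illustration).
bn = nn.BatchNorm2d(4)

# Learnable affine terms are parameters ...
param_names = [name for name, _ in bn.named_parameters()]
# ... while the running statistics are registered as buffers.
buffer_names = [name for name, _ in bn.named_buffers()]

print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']
```

So a loop over `parameters()` alone never touches `running_mean` or `running_var`.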
Should I use this function instead?
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
    # buffers (e.g. BatchNorm running_mean / running_var) are not returned by parameters()
    for current_buffers, ma_buffers in zip(current_model.buffers(), ma_model.buffers()):
        old_weight, up_weight = ma_buffers.data, current_buffers.data
        ma_buffers.data = ema_updater.update_average(old_weight, up_weight)
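One caveat when looping over `buffers()`: BatchNorm also registers an integer buffer, `num_batches_tracked`, and averaging it would turn it into a float tensor. The sketch below is one possible way to handle that, not the repo's implementation. The `EMA` class is a hypothetical stand-in for whatever `ema_updater` is, the two-layer `online` model is made up for the example, and copying integer buffers instead of averaging them is an assumption about the desired behaviour:

```python
import copy
import torch
import torch.nn as nn

class EMA:
    """Hypothetical stand-in for the ema_updater object used above."""
    def __init__(self, beta):
        self.beta = beta

    def update_average(self, old, new):
        return old * self.beta + (1 - self.beta) * new

@torch.no_grad()
def update_moving_average(ema_updater, ma_model, current_model):
    # Learnable weights: exponential moving average, as in the original function.
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        ma_params.data = ema_updater.update_average(ma_params.data, current_params.data)
    # Buffers: average floating-point ones (running_mean, running_var) and
    # copy integer ones (num_batches_tracked) so their dtype is preserved.
    for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
        if ma_buffer.is_floating_point():
            ma_buffer.data = ema_updater.update_average(ma_buffer.data, current_buffer.data)
        else:
            ma_buffer.data.copy_(current_buffer.data)

# Toy online/target pair, only for illustration.
online = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
target = copy.deepcopy(online)

online.train()
online(torch.randn(8, 3, 16, 16))  # one forward pass moves the online running stats

update_moving_average(EMA(0.99), target, online)
print(target[1].running_mean)         # now 1% of the way towards the online value
print(target[1].num_batches_tracked)  # copied, still an integer tensor
```

Whether the running statistics should be averaged at all is the open question in this thread; the sketch only shows how the update could be written if they are.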
@luyvlei ohh yes, i do believe you are correct, and this paper also came to a similar conclusion https://arxiv.org/abs/2101.07525
Do I need to do a test to verify this modification? If this modification is effective, I can submit a PR and test report. @lucidrains
@luyvlei so i think the issue is because the batchnorm statistics are already a moving average - i'll have to read the momentum squared paper above in detail and see if the conclusions are sound
as an aside, there are papers that are starting to use SimSiam (Kaiming's work, where the teacher is the same as the student but with a stop gradient) successfully, and it does not require exponential moving averages as BYOL does. so i'm wondering how important these little details are, and whether it is worth the time to even debug
https://arxiv.org/abs/2111.00210
https://arxiv.org/abs/2110.05208
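On the point above that the batchnorm statistics are already a moving average: in PyTorch, a train-mode forward pass updates the running statistics as `running = (1 - momentum) * running + momentum * batch_stat`, with the unbiased batch variance used for `running_var`. A small check of that rule, assuming a fresh `nn.BatchNorm1d` layer with default `momentum=0.1` and an arbitrary random batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)   # running_mean starts at 0, running_var at 1
x = torch.randn(16, 3)   # arbitrary batch, for illustration only

bn.train()
bn(x)                    # a train-mode forward pass updates the buffers

momentum = bn.momentum
expected_mean = (1 - momentum) * torch.zeros(3) + momentum * x.mean(dim=0)
# Tensor.var is unbiased by default, matching what BatchNorm stores in running_var.
expected_var = (1 - momentum) * torch.ones(3) + momentum * x.var(dim=0)

mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-6)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-6)
print(mean_ok, var_ok)
```

Applying the EMA update to these buffers would therefore average quantities that are themselves moving averages.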
Hi all, I am also trying to reproduce BYOL results and am falling a bit short (~1%) and am wondering if this might be related to the reason why.
I figure there are two options during pretraining:
1. Use local batch statistics for batch normalization (in PyTorch the target model must be in train mode, which is the current state in this repo).
2. Use the running mean and running var for batch normalization (in PyTorch the target model must be in eval mode, while the online network is in train mode).
I believe option 1 is correct based on my reading of the paper and looking through some implementations. If option 1 is correct, no changes are needed. Also, since we feed the exact same images to the target and online networks, the running mean and running var calculated should be the same in the end.
If option 2 is correct, then we would have to copy the buffers as suggested above.
As an aside, I believe the issue in my repro is that I am following option 1 and have SyncBatchNorm for the online network, but not for the target network.
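The two options differ only in the mode of the target network's BatchNorm layers. A minimal sketch of the difference, assuming a toy `online`/`target` pair (hypothetical names and layer sizes, not this repo's classes):

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
online = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
target = copy.deepcopy(online)
x = torch.randn(32, 8) * 3 + 1   # arbitrary batch with non-trivial mean and scale

online.train()

# First option: target in train mode, so BatchNorm normalizes with the
# statistics of the current batch (and also updates its own running buffers).
target.train()
with torch.no_grad():
    out_batch_stats = target(x)

# Second option: target in eval mode, so BatchNorm normalizes with the
# running_mean / running_var buffers and leaves them unchanged.
target.eval()
with torch.no_grad():
    out_running_stats = target(x)

print(out_batch_stats.mean(dim=0))    # close to zero per feature
print(out_running_stats.mean(dim=0))  # generally not zero
```

Only in the second case do the target's buffers affect its output during pretraining, which is why that option would need them copied or averaged. For the SyncBatchNorm mismatch mentioned above, `torch.nn.SyncBatchNorm.convert_sync_batchnorm` can convert the BatchNorm layers of a module, so applying it to both networks may be one way to keep them consistent in a distributed run.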