Calculate batch norm statistics loss in parallel training
dohe0342 opened this issue · 3 comments
dohe0342 commented
Hello, I have a question about the batch norm statistics loss.
Consider parallel training: I have 8 GPUs, and each GPU can hold a batch size of 128.
However, the batch norm statistics loss is computed independently on each GPU; the GPUs share only their gradients, not the statistics of the whole batch (1024). I think this can degrade image quality.
So here is my question: how can I compute the batch norm statistics loss in parallel training as if it were calculated over the whole batch rather than each per-GPU mini-batch?
hkunzhe commented
If you are using DistributedDataParallel, try converting the BatchNorm layers to SyncBatchNorm ones.
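For reference, a minimal sketch, assuming `model` and `local_rank` are already set up for a DDP run:

```python
import torch

# Replace every BatchNorm*d layer with SyncBatchNorm, which computes
# batch statistics over the global batch across all participating ranks.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Wrap in DDP as usual; SyncBatchNorm requires an initialized process group.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```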
dohe0342 commented
I know about SyncBatchNorm, but DeepInversion has to compute the loss over every pixel of the feature maps, and my GPU can't handle that memory cost.
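One possible workaround, as a sketch rather than anything from the DeepInversion code: keep the activations local to each GPU and all-reduce only the per-channel moments inside the BN forward hook. The autograd-aware collective in `torch.distributed.nn` lets gradients flow back to every rank, and only two small vectors per layer cross the network. `GlobalBNStatLoss` is a hypothetical name, not part of the repo:

```python
import torch
import torch.distributed as dist
from torch.distributed.nn import all_reduce  # autograd-aware collective


class GlobalBNStatLoss:
    """Forward hook that matches the global batch statistics of a BN layer
    to its running statistics, averaging the moments across all ranks."""

    def __init__(self, bn_module):
        self.bn = bn_module
        self.loss = None
        bn_module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        x = inputs[0]
        # Per-GPU first and second moments over (N, H, W); shape (C,).
        mean = x.mean(dim=[0, 2, 3])
        sq_mean = (x * x).mean(dim=[0, 2, 3])
        if dist.is_initialized():
            world = dist.get_world_size()
            # Differentiable all-reduce: every rank ends up with the same
            # global moments, so the loss reflects the whole batch.
            mean = all_reduce(mean) / world
            sq_mean = all_reduce(sq_mean) / world
        var = sq_mean - mean * mean
        self.loss = (torch.norm(module.running_mean - mean, 2)
                     + torch.norm(module.running_var - var, 2))
```

Since only the per-channel mean and mean-of-squares are communicated, peak memory stays the same as single-GPU training while the loss is taken against the full 1024-sample statistics.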