boschresearch/OASIS

EMA, 3D noise

pmh9960 opened this issue · 2 comments

Hi. @edgarschnfld @SushkoVadim

I am a student studying semantic image synthesis. Thank you for the great work. I have two questions about differences between the paper and the code.

  1. EMA
    As you cite [Yaz et al., 2018], the exponential moving average is a good technique for training GANs. However, in your code,

    OASIS/utils/utils.py

    Lines 125 to 132 in 6e728ec

    def update_EMA(model, cur_iter, dataloader, opt, force_run_stats=False):
        # update weights based on new generator weights
        with torch.no_grad():
            for key in model.module.netEMA.state_dict():
                model.module.netEMA.state_dict()[key].data.copy_(
                    model.module.netEMA.state_dict()[key].data * opt.EMA_decay +
                    model.module.netG.state_dict()[key].data * (1 - opt.EMA_decay)
                )

    I think the code below might need to be added:

    model.module.netG.state_dict()[key].data.copy_(
        model.module.netEMA.state_dict()[key].data
    )

Otherwise, netG is not trained using EMA.

Yaz, Yasin, et al. "The unusual effectiveness of averaging in GAN training." International Conference on Learning Representations. 2018.

  2. 3D noise

If I do not misunderstand your paper, it says that the noise in OASIS is sampled from a 3D normal distribution, and that this is one of the main differences from SPADE.
However, in your code:

if not self.opt.no_3dnoise:
    dev = seg.get_device() if self.opt.gpu_ids != "-1" else "cpu"
    z = torch.randn(seg.size(0), self.opt.z_dim, dtype=torch.float32, device=dev)
    z = z.view(z.size(0), self.opt.z_dim, 1, 1)
    z = z.expand(z.size(0), self.opt.z_dim, seg.size(2), seg.size(3))
    seg = torch.cat((z, seg), dim=1)

Here the noise is not sampled from a 3D normal distribution. It is sampled from a 1D normal distribution and then expanded to 3D, which replicates the same vector spatially.
In my opinion, this code should be replaced by

z = torch.randn(seg.shape, ...)
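Spelled out, a per-pixel variant along these lines might look like the sketch below, keeping z_dim noise channels so that the concatenation with seg still matches (my reading of the suggestion, not code from the repository):

z = torch.randn(seg.size(0), self.opt.z_dim, seg.size(2), seg.size(3),
                dtype=torch.float32, device=dev)
seg = torch.cat((z, seg), dim=1)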

I think both parts are quite crucial for your paper. If there is a reason for these choices, or if I have misunderstood something, please let me know.

Thank you.

Hi,

  1. The netEMA checkpoint is meant simply to track the running average of the weights of the generator network netG. When used at inference instead of netG, it indeed has the potential to improve performance. It is usually not meant to be used during training. In your suggested example, netG is not allowed to be different from netEMA, which imposes a strong constraint on the generator. This would likely impair training by making it much harder for the generator to fool the discriminator netD.
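    For illustration, a minimal sketch of this pattern with hypothetical names (not taken from the repository): the EMA copy shadows the trained generator during training and is only swapped in at inference.

    import copy
    import torch

    netG = torch.nn.Linear(8, 8)    # stand-in for the generator
    netEMA = copy.deepcopy(netG)    # shadow copy, never trained directly
    decay = 0.999

    def update_ema(netG, netEMA, decay):
        # Only the shadow copy is overwritten; netG keeps training freely.
        # (Buffers are ignored here for brevity; the repository iterates
        # over the full state_dict instead.)
        with torch.no_grad():
            for p_ema, p in zip(netEMA.parameters(), netG.parameters()):
                p_ema.copy_(p_ema * decay + p * (1 - decay))

    # In the training loop, after each optimizer step:
    # update_ema(netG, netEMA, decay)
    # At inference, evaluate netEMA instead of netG.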
  2. We indeed do not assume sampling from a 3D normal distribution, and use a simpler "replication" strategy instead. Please refer to Appendix A.7 in the paper for the related discussion, and to this paragraph from Sec. 3.3 of the main paper:

Note that for simplicity during training we sample the 3D noise tensor globally, i.e. per-channel, replicating each channel value spatially along the height and width of the tensor. We analyse alternative ways of sampling 3D noise during training in App. A.7.
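For concreteness, a small sketch of the two sampling schemes side by side (hypothetical shapes, not code from the repository):

import torch

batch, z_dim, H, W = 2, 64, 8, 8

# Global ("replication") sampling, as in the training code quoted above:
# one z_dim-dimensional vector per image, replicated over all spatial positions.
z_global = torch.randn(batch, z_dim, 1, 1).expand(batch, z_dim, H, W)
assert torch.equal(z_global[..., 0, 0], z_global[..., H - 1, W - 1])

# Per-pixel 3D sampling, one of the alternatives analysed in App. A.7:
# an independent sample at every spatial location.
z_pixel = torch.randn(batch, z_dim, H, W)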

Thank you for the fast reply and the detailed explanation!
I will take this into account in my work.

Thank you.