Correct min_snr weighting for v-prediction objective
parlance-zz opened this issue · 1 comment
Hello,
I wanted to confirm that the loss-weight formulation for the v-prediction objective in the code is correct.
In guided_diffusion/gaussian_diffusion.py, lines 861 to 864, we have:
    elif self.mse_loss_weight_type.startswith("vmin_snr_"):
        k = float(self.mse_loss_weight_type.split('vmin_snr_')[-1])
        # min{snr, k}
        mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / (snr + 1)
Should the snr inside the stack also be +1, i.e. min{snr + 1, k}? Without it, the loss weights are always < 1, the weight for SNRs near 0 is also near 0, and the weight at zero terminal SNR is exactly 0.
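To make the concern concrete, here is a minimal scalar sketch comparing the two weightings. It uses plain Python floats instead of the `th` tensors in the repository. The function names are made up for illustration, and the second function is only my reading of the "+1" variant in question, not something taken from the codebase; whether that variant is the right fix is exactly what this issue asks.

```python
def vmin_snr_weight(snr: float, k: float) -> float:
    """Scalar form of the quoted expression: min(snr, k) / (snr + 1)."""
    return min(snr, k) / (snr + 1.0)


def vmin_snr_plus_one_weight(snr: float, k: float) -> float:
    """Hypothetical variant (an assumption): min(snr + 1, k) / (snr + 1)."""
    return min(snr + 1.0, k) / (snr + 1.0)


k = 5.0  # example clamp value, as in a "vmin_snr_5" setting
for snr in (0.0, 0.01, 1.0, 5.0, 100.0):
    current = vmin_snr_weight(snr, k)
    plus_one = vmin_snr_plus_one_weight(snr, k)
    print(f"snr={snr:>6}: current={current:.4f}  plus_one={plus_one:.4f}")
```

With the quoted expression the weight at snr = 0 is exactly 0 and stays below 1 for every SNR, since min(snr, k) <= snr < snr + 1. With the hypothetical variant the weight is 1 whenever snr + 1 <= k.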
Thank you.
IMO, mathematically,