TiankaiHang/Min-SNR-Diffusion-Training

Correct min_snr weighting for v-prediction objective

parlance-zz opened this issue · 1 comment

Hello,

I just wanted to confirm that the formulation of the loss weight for the v-prediction objective in the code is correct.

In guided_diffusion/gaussian_diffusion.py, lines 861 to 864 we have:

```python
elif self.mse_loss_weight_type.startswith("vmin_snr_"):
    k = float(self.mse_loss_weight_type.split('vmin_snr_')[-1])
    # min{snr, k}
    mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / (snr + 1)
```

Should the snr inside the stack also be +1? Without it, the loss weights are always < 1, the weights for SNRs near 0 are also near 0, and the weight at zero terminal SNR would be exactly 0.
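To illustrate the concern, here is a minimal sketch (plain Python, scalar SNR values) comparing the weight as implemented against the "+1 inside the stack" variant the question proposes. The helper names `w_code` and `w_plus_one` are hypothetical, not from the repository:

```python
def w_code(snr, k=5.0):
    # min(SNR, k) / (SNR + 1) -- as implemented in the repository
    return min(snr, k) / (snr + 1)

def w_plus_one(snr, k=5.0):
    # min(SNR + 1, k) / (SNR + 1) -- the variant proposed in the question
    return min(snr + 1, k) / (snr + 1)

for snr in [0.0, 0.1, 1.0, 10.0, 100.0]:
    print(f"SNR={snr:6.1f}  code={w_code(snr):.4f}  variant={w_plus_one(snr):.4f}")
```

At SNR = 0 the implemented weight is exactly 0 while the proposed variant is 1, which is the behavior the question is asking about.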

Thank you.

IMO, mathematically, $$w_t=\frac {\min(\text{SNR}, \gamma)} {\text{SNR}+1}$$ for the v-objective is equivalent to $$w_t=\min(\text{SNR}, \gamma)$$ for $x_0$ prediction. The latter also approaches 0 as SNR approaches 0, so this behavior is intended.
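The equivalence holds because, with $v = \sqrt{\bar\alpha}\,\epsilon - \sqrt{1-\bar\alpha}\,x_0$ and $\text{SNR} = \bar\alpha/(1-\bar\alpha)$, the v-space squared error is exactly $(\text{SNR}+1)$ times the $x_0$-space squared error, so dividing the v weight by $\text{SNR}+1$ recovers the $x_0$ weighting. A minimal numeric sketch with scalar values (variable names like `abar` and `v_hat` are illustrative):

```python
import math

abar = 0.3                         # example value of alpha-bar at some timestep t
snr = abar / (1 - abar)            # SNR = alpha-bar / (1 - alpha-bar)

x0, eps = 0.7, -1.2                # example clean sample and noise
xt = math.sqrt(abar) * x0 + math.sqrt(1 - abar) * eps
v = math.sqrt(abar) * eps - math.sqrt(1 - abar) * x0

v_hat = v + 0.1                    # a perturbed "prediction" of v
x0_hat = math.sqrt(abar) * xt - math.sqrt(1 - abar) * v_hat  # implied x0 prediction

# v-space error equals (SNR + 1) times the x0-space error
lhs = (v - v_hat) ** 2
rhs = (snr + 1) * (x0 - x0_hat) ** 2
assert math.isclose(lhs, rhs)
```

Consequently, weighting the v-loss by $\min(\text{SNR},\gamma)/(\text{SNR}+1)$ gives the same effective objective as weighting the $x_0$-loss by $\min(\text{SNR},\gamma)$.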