TiankaiHang/Min-SNR-Diffusion-Training

v-prediction implicit weighting (appendix math)

Closed issue · 1 comment

First, congratulations on a very interesting paper!

I'm attempting to follow the math in the appendix, which derives implicit weighting schemes for different objectives. I follow the noise-prediction math in its entirety.

For v-prediction, I follow up to the fifth line (looking at the arXiv version of the paper, v2: https://arxiv.org/abs/2303.09556), but then I arrive at a different result (abbreviating the notation slightly):

$$
\begin{aligned}
\ell &= \left\| \frac{\alpha_t^2 + \sigma_t^2}{\sigma_t} (x_0 - \hat{x}_{\theta}) \right\|_2^2 \\
&= \frac{1}{\sigma_t^2} \left\| (\alpha_t^2 + \sigma_t^2) (x_0 - \hat{x}_{\theta}) \right\|_2^2 \\
&= \frac{(\alpha_t^2 + \sigma_t^2)^2}{\sigma_t^2} \left\| x_0 - \hat{x}_{\theta} \right\|_2^2 \\
&= \frac{\sigma_t^2}{\sigma_t^2} \cdot \frac{(\alpha_t^2 + \sigma_t^2)^2}{\sigma_t^2} \left\| x_0 - \hat{x}_{\theta} \right\|_2^2 \\
&= \sigma_t^2 \left( \frac{\alpha_t^2 + \sigma_t^2}{\sigma_t^2} \right)^2 \left\| x_0 - \hat{x}_{\theta} \right\|_2^2 \\
&= \sigma_t^2 \left( \frac{\alpha_t^2}{\sigma_t^2} + 1 \right)^2 \left\| x_0 - \hat{x}_{\theta} \right\|_2^2 \\
&= \sigma_t^2 \left( \mathrm{SNR}_t + 1 \right)^2 \left\| x_0 - \hat{x}_{\theta} \right\|_2^2
\end{aligned}
$$

Should I not be squaring the factors pulled out of the squared $L_2$ objective? If I do not, I don't get $\sigma_t^2$ in the denominator and cannot complete the SNR term. If I do, however, it seems to me I must also square $(\alpha_t^2 + \sigma_t^2)$, which I think differs from the authors' result, if I am understanding their work correctly. The noise-prediction proof has a similar step, and it does square the terms pulled out of the squared $L_2$ distance.
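For what it's worth, a quick numerical sanity check of the squaring step, $\| c\,v \|_2^2 = c^2 \| v \|_2^2$ for a scalar $c$. This is only a sketch: the helper `squared_l2`, the random vector `v`, and the scalar `c` are made up for illustration and have nothing to do with the paper's code.

```python
import math
import random


def squared_l2(v):
    """Squared L2 norm of a plain Python list of floats."""
    return sum(x * x for x in v)


random.seed(0)
v = [random.gauss(0.0, 1.0) for _ in range(8)]  # arbitrary stand-in for (x_0 - x_hat)
c = 3.7                                         # arbitrary scalar factor

lhs = squared_l2([c * x for x in v])  # || c * v ||^2
rhs = c ** 2 * squared_l2(v)          # c^2 * || v ||^2

# A scalar pulled out of a squared norm comes out squared.
assert math.isclose(lhs, rhs, rel_tol=1e-12)
```

So any scalar pulled out of the squared norm does come out squared, which is what the noise-prediction step does as well.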

I hope I have not made a trivial error, but it is certainly possible, and I apologize if that is the case.

A separate question: have you compared the weighting scheme you propose against the implicit weighting scheme of v-prediction?

Thanks for your help, and congrats again on a very interesting result!

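I realized my error: I neglected that $\alpha_t = \sqrt{1 - \sigma_t^2}$, so $\alpha_t^2 + \sigma_t^2 = 1$. With that, $\mathrm{SNR}_t + 1 = \frac{\alpha_t^2 + \sigma_t^2}{\sigma_t^2} = \frac{1}{\sigma_t^2}$, and my expression $\sigma_t^2 (\mathrm{SNR}_t + 1)^2 = \frac{1}{\sigma_t^2}$ is equivalent to $\mathrm{SNR}_t + 1$.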
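For anyone who lands here with the same question, a small numerical check of that equivalence. It is a sketch under the variance-preserving assumption $\alpha_t^2 + \sigma_t^2 = 1$; the function names `weight_general` and `weight_vp` and the sample $\sigma_t$ values are my own and are not taken from the paper or this repository.

```python
import math


def weight_general(alpha, sigma):
    """sigma^2 * (SNR + 1)^2, with SNR = alpha^2 / sigma^2 (no constraint assumed)."""
    snr = alpha ** 2 / sigma ** 2
    return sigma ** 2 * (snr + 1.0) ** 2


def weight_vp(alpha, sigma):
    """SNR + 1, the form that holds when alpha^2 + sigma^2 = 1."""
    snr = alpha ** 2 / sigma ** 2
    return snr + 1.0


# Variance-preserving case: alpha is determined by sigma.
for sigma in (0.05, 0.3, 0.7, 0.99):
    alpha = math.sqrt(1.0 - sigma ** 2)
    assert math.isclose(weight_general(alpha, sigma), weight_vp(alpha, sigma), rel_tol=1e-9)

# Without the constraint the two forms differ, e.g. alpha = 1, sigma = 2.
assert not math.isclose(weight_general(1.0, 2.0), weight_vp(1.0, 2.0), rel_tol=1e-9)
```

The two forms agree only because of the constraint; for a schedule where $\alpha_t^2 + \sigma_t^2 \neq 1$, the general expression is the one to use.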