wyu-du/GP-VAE

prior_logvar should be 1 when calculating KL

NIL-zhuang opened this issue · 1 comment

Thank you for your great work! I have been learning about VAEs recently, and your paper has given me great inspiration.

In a vanilla VAE, the KL divergence should be $D_{KL}(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1))$. But while reading your code, I found that in model/t5/t5_vae.py, line 151, you set prior_logvar to 0. Is there a mistake?

prior_mean = torch.zeros([hidden_states.size(0), posterior_mean.size(-1)]) \
    .to(posterior_mean.dtype).to(posterior_mean.device)
prior_logvar = torch.zeros([hidden_states.size(0), posterior_logvar.size(-1)]) \
    .to(posterior_logvar.dtype).to(posterior_logvar.device)
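For reference, here is a minimal sketch (not the repo's actual KL function; `gaussian_kl` is a hypothetical name) of the diagonal-Gaussian KL written in terms of log-variances. With `prior_mean = 0` and `prior_logvar = 0` (i.e. prior variance `exp(0) = 1`), it reduces to the usual KL against the standard normal $\mathcal{N}(0, I)$:

```python
import torch

# Illustrative sketch, not the repository's implementation:
# KL( N(posterior_mean, exp(posterior_logvar)) || N(prior_mean, exp(prior_logvar)) )
# for diagonal Gaussians, summed over latent dims and averaged over the batch.
def gaussian_kl(posterior_mean, posterior_logvar, prior_mean, prior_logvar):
    var_ratio = torch.exp(posterior_logvar - prior_logvar)
    mean_diff_sq = (posterior_mean - prior_mean).pow(2) / torch.exp(prior_logvar)
    kl_per_dim = 0.5 * (var_ratio + mean_diff_sq - 1.0
                        - (posterior_logvar - prior_logvar))
    return kl_per_dim.sum(dim=-1).mean()
```

Setting `prior_logvar` to zeros therefore corresponds to a prior variance of 1, not a variance of 0.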

Sorry, I forgot the log: the prior variance is 1, so its log-variance is log(1) = 0 🙉