prior_logvar should be 1 when calculating KL
NIL-zhuang opened this issue · 1 comment
NIL-zhuang commented
Thank you for your great work! I have been learning about VAEs recently, and your paper has given me great inspiration.
In a naive VAE the prior is N(0, I), so shouldn't prior_logvar be 1 rather than 0 when calculating the KL divergence? The relevant code is:
# prior parameters fed into the KL term (both mean and log-variance are zeros)
prior_mean = torch.zeros([hidden_states.size(0), posterior_mean.size(-1)]) \
    .to(posterior_mean.dtype).to(posterior_mean.device)
prior_logvar = torch.zeros([hidden_states.size(0), posterior_logvar.size(-1)]) \
    .to(posterior_logvar.dtype).to(posterior_logvar.device)
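For reference, the closed-form KL divergence between the diagonal-Gaussian posterior N(posterior_mean, exp(posterior_logvar)) and the prior N(prior_mean, exp(prior_logvar)) can be computed along these lines (a minimal sketch using the tensor names from the snippet above, not necessarily the repository's exact implementation):

import torch

def gaussian_kl(posterior_mean, posterior_logvar, prior_mean, prior_logvar):
    # KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) for diagonal Gaussians,
    # summed over the latent dimension:
    #   0.5 * sum( logvar_p - logvar_q
    #              + (exp(logvar_q) + (mu_q - mu_p)^2) / exp(logvar_p) - 1 )
    return 0.5 * torch.sum(
        prior_logvar - posterior_logvar
        + (posterior_logvar.exp() + (posterior_mean - prior_mean) ** 2) / prior_logvar.exp()
        - 1.0,
        dim=-1,
    )

With prior_mean = 0 and prior_logvar = 0 (i.e. unit variance, since log 1 = 0), this reduces to the familiar -0.5 * sum(1 + posterior_logvar - posterior_mean**2 - posterior_logvar.exp()) term of the standard VAE objective.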
NIL-zhuang commented
Sorry, I forgot the log 🙉. prior_logvar is a log-variance, so the correct value is log(1) = 0, and the zeros in the code are right.