KLD term in loss_function()
fedeotto opened this issue · 1 comment
fedeotto commented
In loss_function, instead of doing

KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD

shouldn't we normalize the KLD term over the batch instead, for example by using torch.mean() rather than torch.sum()?
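A minimal sketch of the two variants I have in mind, assuming the usual loss_function(recon_x, x, mu, logvar) signature from the PyTorch VAE example (the function names below are just for illustration):

import torch
import torch.nn.functional as F

def loss_function_sum(recon_x, x, mu, logvar):
    # Current behaviour: both terms are summed over the whole batch.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def loss_function_batch_mean(recon_x, x, mu, logvar):
    # Suggested variant: sum over latent dimensions per sample,
    # then average over the batch by dividing both terms by batch size.
    batch_size = x.size(0)
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum') / batch_size
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return BCE + KLD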
unnir commented
Please check the KL-divergence equation; there is no mean operation in it. Also, KLD is just a scalar (not a vector).
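For reference, the closed-form KL divergence between the approximate posterior q(z|x) = N(mu, diag(sigma^2)) and the standard normal prior N(0, I) is

D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)

which is exactly what -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) computes, with logvar = log(sigma^2); the sum over the latent dimensions reduces it to a single scalar.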