unnir/cVAE

KLD term in loss_function()

fedeotto opened this issue · 1 comment

In loss_function, instead of doing

    # KL divergence, summed over the latent dimensions and the batch
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

shouldn’t we normalize the KLD term over the batch instead, for example by using torch.mean()?
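
A rough sketch of the variant I have in mind, assuming mu and logvar have shape (batch_size, latent_dim) and that BCE is the reconstruction term from the same function (loss_function_mean is just an illustrative name, not code from the repo):

    import torch

    def loss_function_mean(BCE, mu, logvar):
        # Hypothetical variant: sum the KL term over latent dimensions
        # to get one value per sample...
        kld_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        # ...then average over the batch instead of summing
        KLD = torch.mean(kld_per_sample)
        return BCE + KLD

Note that if the KLD term is batch-averaged, the reconstruction term would also need to be batch-averaged (e.g. reduction='mean' instead of 'sum') so that the relative weighting of the two terms does not change with batch size.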

unnir commented

Please check the KL-divergence equation; there is no mean operation there. Also, KLD is just a scalar (not a vector).
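
For reference, the closed-form KL term for a diagonal Gaussian posterior against a standard normal prior, which the code above implements with logvar = log σ²:

    D_{KL}\big(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\big)
        = -\frac{1}{2} \sum_{j} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)

The sum runs over the latent dimensions (and, with a plain torch.sum, over the whole batch as well); the formula itself contains no mean.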