thwjoy/meme

Question about KL in loss calculation


Hi,

Thank you so much for this excellent work and the very nice code! I have a question about the loss calculation (model/meme.py lines 96 and 104) that I hope you could discuss briefly.

The first term in the original loss is [equation image]
In it, I can identify the `w` and `kl` terms of the code. In the original loss, `w` and `kl` are multiplied together inside the expectation, so I would simply sample z from `q_zs` and then evaluate every conditional probability inside the expectation on those samples. In the code, however, `kl` is computed analytically while `w` is estimated by sampling z. Could you explain why?

As I understand it, this is like computing E[f(x)g(x)] where E[g(x)] has a closed form: the code estimates E[f(x)] by Monte Carlo sampling and then multiplies it by the closed-form E[g(x)], whereas I would treat f(x)g(x) as a single function and estimate the whole expectation by sampling. Since E[f(x)g(x)] = E[f(x)]E[g(x)] + Cov(f(x), g(x)), the two approaches only agree when f and g are uncorrelated under the sampling distribution.
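To make the difference concrete, here is a small toy sketch (not the actual code in model/meme.py). It assumes 1-D Gaussians q = N(1, 0.5²) and p = N(0, 1), takes g(z) = log q(z) − log p(z) (so E_q[g] is the closed-form KL), and uses a hypothetical weight w(z) = exp(z) as a stand-in for the sampled `w`. The function names `kl_normal` and `compare_estimators` are made up for illustration. It compares the "joint" Monte Carlo estimate of E[w·g] with the "factored" estimate E[w]·KL.

```python
import numpy as np


def log_normal_pdf(z, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at z."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (z - mu) ** 2 / (2 * sigma**2)


def kl_normal(mu_q, s_q, mu_p, s_p):
    """Closed-form KL(N(mu_q, s_q^2) || N(mu_p, s_p^2))."""
    return np.log(s_p / s_q) + (s_q**2 + (mu_q - mu_p) ** 2) / (2 * s_p**2) - 0.5


def compare_estimators(w_fn, mu_q=1.0, s_q=0.5, mu_p=0.0, s_p=1.0, n=200_000, seed=0):
    """Hypothetical helper: joint vs. factored estimates of E_q[w(z) * log(q(z)/p(z))]."""
    rng = np.random.default_rng(seed)
    z = rng.normal(mu_q, s_q, size=n)  # samples z ~ q
    w = w_fn(z)  # f(z): stand-in for the sampled weight
    g = log_normal_pdf(z, mu_q, s_q) - log_normal_pdf(z, mu_p, s_p)  # g(z), E_q[g] = KL
    joint = np.mean(w * g)  # Monte Carlo estimate of E[f g] as a whole
    factored = np.mean(w) * kl_normal(mu_q, s_q, mu_p, s_p)  # MC E[f] times analytic E[g]
    cov = np.mean(w * g) - np.mean(w) * np.mean(g)  # sample Cov(f, g)
    return joint, factored, cov


# A constant weight is uncorrelated with g, so both estimates agree (up to MC noise).
print(compare_estimators(lambda z: np.ones_like(z)))
# A weight that depends on z is correlated with g, so the factored estimate is biased.
print(compare_estimators(np.exp))
```

With a constant weight the two numbers match closely; with w(z) = exp(z) they differ by roughly the sample covariance, which is the term the factored version drops.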

I hope I have explained my question clearly enough. Thanks a lot in advance for your help!

Best regards,
Xin