Question about KL in loss calculation
Hi,
Thank you so much for this excellent work and the very nice code! I have a question about the loss calculation (model/meme.py, lines 96 and 104) that I hope you could discuss briefly.
The first term in the original loss is
where I can identify the w and kl terms of the code. In the original loss, w and kl are multiplied together inside the expectation, so I would simply sample z from q_zs and then compute every conditional probability in the expectation. In the code, however, you compute kl analytically and w by sampling z. Could you explain why?
It looks to me like computing E[f(x)g(x)] where E[g(x)] has a closed form: you estimate E[f(x)] by MC sampling and then multiply it by E[g(x)] to get the final result, whereas I would treat f(x)g(x) as a whole and estimate the entire expectation by sampling. These two are not the same in general, since E[f(x)g(x)] = E[f(x)]E[g(x)] + Cov(f(x), g(x)), so the factored version is only exact when f and g are uncorrelated.
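To make the concern concrete, here is a small Python sketch comparing the two estimators on toy functions. The functions f and g, the helpers mc_joint and mc_factored, and the standard-normal sampler are all hypothetical stand-ins, not the w and kl computations in model/meme.py.

```python
import random
import statistics

def mc_joint(f, g, sampler, n):
    """Estimate E[f(x) g(x)] by averaging the product over the same samples."""
    xs = [sampler() for _ in range(n)]
    return statistics.fmean(f(x) * g(x) for x in xs)

def mc_factored(f, expected_g, sampler, n):
    """Estimate E[f(x)] by MC, then multiply by a closed-form E[g(x)]."""
    xs = [sampler() for _ in range(n)]
    return statistics.fmean(f(x) for x in xs) * expected_g

rng = random.Random(0)
sampler = lambda: rng.gauss(0.0, 1.0)  # hypothetical: x ~ N(0, 1)
N = 100_000

# Hypothetical case 1: f and g are correlated (f(x) = g(x) = x).
# True E[f g] = E[x^2] = 1, but E[f] * E[g] = 0 * 0 = 0.
joint_corr = mc_joint(lambda x: x, lambda x: x, sampler, N)
factored_corr = mc_factored(lambda x: x, 0.0, sampler, N)

# Hypothetical case 2: g does not depend on x (g(x) = 3), so Cov(f, g) = 0.
# With f(x) = x^2, both estimators target E[x^2] * 3 = 3.
joint_const = mc_joint(lambda x: x * x, lambda x: 3.0, sampler, N)
factored_const = mc_factored(lambda x: x * x, 3.0, sampler, N)

print(f"correlated:  joint={joint_corr:.3f}  factored={factored_corr:.3f}")
print(f"constant g:  joint={joint_const:.3f}  factored={factored_const:.3f}")
```

In the correlated case the joint estimate lands near 1 while the factored one is 0; in the constant-g case both land near 3. So the factored form is a valid (and lower-variance) shortcut only when g is uncorrelated with f under the sampling distribution.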
I hope I have explained my question clearly. Thanks in advance for your help!
Best regards,
Xin