SanghunYun/UDA_pytorch

kl divergence uses ori_prob instead of ori_log_prob

stellaywu opened this issue · 1 comment

Thanks for the nice implementation!

I'm trying to reproduce the results and am wondering why, for the KL divergence loss, you used the original probability rather than the original log-probability against the augmented log-probability. This looks different from the TensorFlow implementation.

unsup_loss = torch.sum(unsup_criterion(aug_log_prob, ori_prob), dim=-1)

The original TensorFlow implementation uses log-probabilities for both arguments:

per_example_kl_loss = kl_for_log_probs(tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask

Thanks !

Sorry, I realized it's a PyTorch thing: `nn.KLDivLoss` expects its input in log-space but its target as plain probabilities (unless `log_target=True`), so passing `ori_prob` here is correct.
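For anyone landing here later, a minimal sketch showing the two conventions agree (random logits are placeholders, not the repo's model outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits_ori = torch.randn(4, 10)  # stand-in for original-example logits
logits_aug = torch.randn(4, 10)  # stand-in for augmented-example logits

ori_prob = F.softmax(logits_ori, dim=-1)
ori_log_prob = F.log_softmax(logits_ori, dim=-1)
aug_log_prob = F.log_softmax(logits_aug, dim=-1)

# PyTorch convention: input is log-probs, target is probs (log_target=False by default)
kl_pt = torch.sum(F.kl_div(aug_log_prob, ori_prob, reduction='none'), dim=-1)

# TF-style kl_for_log_probs: KL(p || q) = sum p * (log p - log q), both args in log-space
kl_tf = torch.sum(ori_prob * (ori_log_prob - aug_log_prob), dim=-1)

print(torch.allclose(kl_pt, kl_tf, atol=1e-6))  # True
```

So `unsup_criterion(aug_log_prob, ori_prob)` computes the same per-example KL as the TensorFlow code; only the expected argument format differs.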