
Loss computations for CLUE

jmayank23 opened this issue · 1 comment

I was really interested in the CLUE method, so I read the paper and looked at the code in order to implement it. I noticed that the KL divergence terms in the paper and in the code look different, and I was wondering if you could help me with this part. If you could share which approach worked better, along with some details on how to compute it, that would be really helpful.
Reference screenshots:

Screen Shot 2023-06-08 at 12 13 34 PM

Screen Shot 2023-06-08 at 12 14 26 PM

Thank you

Hi @jmayank23. Thanks for your interest in CLUE!

The differences you mentioned are probably the following:

  1. The code normalizes both the negative log likelihood (.mean()) and the KL term (/ x[k].shape[1]) by the number of features in the data modality, so that different modalities receive more or less the same attention during training.
  2. The code has a hyperparameter self.lam_kl that controls the weight of the KL term, which was optimized using cross-validation. This was omitted from the paper to keep the focus on the main message of cross encoding.
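
As a rough illustration of the two points above, here is a minimal NumPy sketch of a per-modality loss. It is not the CLUE implementation: the function names (gaussian_nll, gaussian_kl, modality_loss), the unit-variance Gaussian likelihood, the standard normal prior, and the averaging over cells are all assumptions made for the example. Only the .mean() on the negative log likelihood, the division of the KL term by x.shape[1], and the lam_kl weight mirror what is described above.

```python
import numpy as np


def gaussian_nll(x, x_hat):
    """Element-wise negative log likelihood under N(x_hat, 1) (assumed likelihood)."""
    return 0.5 * np.log(2.0 * np.pi) + 0.5 * (x - x_hat) ** 2


def gaussian_kl(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions, one value per cell."""
    return 0.5 * np.sum(np.exp(logvar) + mu ** 2 - 1.0 - logvar, axis=1)


def modality_loss(x, x_hat, mu, logvar, lam_kl=1.0):
    """Hypothetical per-modality loss: feature-normalized NLL plus weighted, feature-normalized KL."""
    n_features = x.shape[1]
    # .mean() averages over cells and features, so the NLL is on a per-feature scale.
    nll = gaussian_nll(x, x_hat).mean()
    # The KL is averaged over cells (an assumption), then divided by the number of
    # features so that it sits on the same per-feature scale as the NLL.
    kl = gaussian_kl(mu, logvar).mean() / n_features
    return nll + lam_kl * kl


# Toy example: 4 cells, 10 features, 3 latent dimensions.
x = np.zeros((4, 10))
mu = np.ones((4, 3))
logvar = np.zeros((4, 3))
loss = modality_loss(x, x, mu, logvar, lam_kl=2.0)
```

With this normalization, a modality with many features does not dominate the total loss just because its summed likelihood is larger, and lam_kl alone sets the relative weight of the KL term.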

Let me know if there are further problems.