dheerajrajagopal/SelfExplain

Inaccurate loss

akshaylive opened this issue

According to Section 2.5 of the paper, the final loss is a weighted combination of three loss terms: the LIL loss, the GIL loss, and the task-based cross-entropy (CE) loss. In the code, however, the LIL, GIL, and task-based logits are first combined into a weighted sum, and the final loss is computed only AFTER that. The two formulations are therefore not equivalent.
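Schematically, the two formulations look like this. The symbols are illustrative and are not taken from the paper or the code: $z_L, z_G, z_T$ stand for the LIL, GIL, and task logit vectors, $y$ for the gold label, and $\lambda_L, \lambda_G, \lambda_T$ for the mixing weights. Expanding CE shows that the linear terms agree and only the log-partition (log-sum-exp) terms differ:

```math
\begin{aligned}
\mathrm{CE}(z, y) &= -\log \mathrm{softmax}(z)_y = -z_y + \mathrm{LSE}(z), \qquad \mathrm{LSE}(z) = \log \textstyle\sum_j e^{z_j} \\[4pt]
\mathcal{L}_{\text{paper}} &= \lambda_L\,\mathrm{CE}(z_L, y) + \lambda_G\,\mathrm{CE}(z_G, y) + \lambda_T\,\mathrm{CE}(z_T, y) \\
&= -\textstyle\sum_k \lambda_k z_{k,y} + \sum_k \lambda_k\,\mathrm{LSE}(z_k) \\[4pt]
\mathcal{L}_{\text{code}} &= \mathrm{CE}(\lambda_L z_L + \lambda_G z_G + \lambda_T z_T,\ y) \\
&= -\textstyle\sum_k \lambda_k z_{k,y} + \mathrm{LSE}\big(\sum_k \lambda_k z_k\big)
\end{aligned}
```

So the two losses coincide only when $\mathrm{LSE}(\sum_k \lambda_k z_k) = \sum_k \lambda_k \mathrm{LSE}(z_k)$, which does not hold in general.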

That is, in general log(softmax(a + b)) is not equal to log(softmax(a)) + log(softmax(b)).
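A quick numerical check of this, as a minimal sketch using only the standard library. The logits `a` and `b` are hypothetical two-class values chosen for illustration, with unit weights; `log_softmax` and `cross_entropy` are helper names introduced here and are not functions from the SelfExplain codebase.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def cross_entropy(logits, label):
    """Negative log-likelihood of the gold label under softmax(logits)."""
    return -log_softmax(logits)[label]

# Hypothetical logits for a 2-class problem (illustrative values only).
a = [2.0, 0.0]  # stands in for one head's logits
b = [0.0, 1.0]  # stands in for another head's logits
label = 0

# Loss computed on the summed logits (the behaviour described for the code).
loss_on_summed_logits = cross_entropy([x + y for x, y in zip(a, b)], label)

# Sum of the per-head losses (the formulation described in Section 2.5).
sum_of_losses = cross_entropy(a, label) + cross_entropy(b, label)

print(f"CE(a + b)     = {loss_on_summed_logits:.4f}")
print(f"CE(a) + CE(b) = {sum_of_losses:.4f}")
```

The two printed values differ substantially, and the gap equals LSE(a) + LSE(b) - LSE(a + b), since the gold-label logit terms cancel.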