why specify `ignore_index=0` in the NLLLoss function in BERTTrainer?
Jasmine969 opened this issue · 1 comment
Jasmine969 commented
trainer/pretrain.py

class BERTTrainer:
    def __init__(self, ...):
        ...
        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)
        ...
I cannot understand why `ignore_index=0` is specified when computing the NLLLoss. For the NSP task, if the ground truth of `is_next` is False (label = 0) but BERT predicts True, the sample is ignored: its loss contribution is 0 (and the batch loss even becomes nan if every label is 0), so the wrong prediction is never penalized. What is the aim of `ignore_index=0` here?
====================
Well, I've since found that `ignore_index=0` is useful for the MLM task, but I still don't agree that the NSP task should share the same NLLLoss with MLM.
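For concreteness, here is a toy check of that concern (a minimal sketch with made-up tensors, not code from this repo): with `ignore_index=0`, every NSP sample whose `is_next` label is 0 is silently dropped from the loss, and a batch containing only such samples yields nan, whereas for MLM label 0 marks non-masked/padding positions, so ignoring it there is exactly what you want.

```python
import torch
import torch.nn as nn

criterion = nn.NLLLoss(ignore_index=0)

# NSP head: log-probabilities over 2 classes (is_next = 0 or 1).
nsp_log_probs = torch.log_softmax(torch.randn(4, 2), dim=-1)
nsp_labels = torch.tensor([0, 0, 0, 0])        # all "not next" -> all ignored
print(criterion(nsp_log_probs, nsp_labels))    # tensor(nan): nothing left to average

# MLM head: log-probabilities over the vocabulary; label 0 marks
# non-masked / padding positions, which we genuinely want to skip.
vocab_size = 10
mlm_log_probs = torch.log_softmax(torch.randn(6, vocab_size), dim=-1)
mlm_labels = torch.tensor([0, 0, 3, 0, 7, 0])  # only two positions are masked
print(criterion(mlm_log_probs, mlm_labels))    # finite loss from those two positions
```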
MingchangLi commented
See #32.
Change
self.criterion = nn.NLLLoss(ignore_index=0)
to
self.criterion = nn.NLLLoss()
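If you want to keep the ignoring behaviour for MLM while still training NSP on both classes, one possible alternative is to use two separate criteria (a sketch only; the names `mask_criterion`, `next_criterion`, `pretrain_loss` and the tensor shapes below are assumptions, not the repo's actual code):

```python
import torch.nn as nn

# label 0 in the MLM target means "not masked", so ignore it there,
# while NSP labels 0/1 are both real classes and must both count.
mask_criterion = nn.NLLLoss(ignore_index=0)   # masked language model loss
next_criterion = nn.NLLLoss()                 # next sentence prediction loss

def pretrain_loss(mask_lm_output, bert_label, next_sent_output, is_next):
    """Combine the two pre-training losses.

    Assumed shapes: mask_lm_output (batch, seq, vocab), bert_label (batch, seq),
    next_sent_output (batch, 2), is_next (batch,).
    """
    # NLLLoss expects the class dimension second, hence the transpose.
    mask_loss = mask_criterion(mask_lm_output.transpose(1, 2), bert_label)
    next_loss = next_criterion(next_sent_output, is_next)
    return mask_loss + next_loss
```

This way label 0 still means "not masked" for the MLM target, while `is_next = 0` remains a real class for NSP.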