loss 计算的IOU weight 是否有bug?
terenceau2 opened this issue · 1 comments
terenceau2 commented
bin_loss = self._filter_criterion(bin_logits, bin_entity_types) * ious
请看loss。py的第64行
如此的话,self._filter_criterion(bin_logits, bin_entity_types) 得出的shape是(num_of_spans,1) 而ious 的shape 是(num_of_spans,)
如此bin_loss就会是outer product,而非element wise multiplication
是否应该改为
bin_loss = self._filter_criterion(bin_logits, bin_entity_types) * ious.unsqueeze(1)
如此得到的shape是(num_of_spans,1),每个candidate的loss。
entity_loss 亦是同理
tricktreat commented
您好,self._filter_criterion(bin_logits, bin_entity_types)的形状也是num_of_spans,),因此这里是element wise multiplication。