kanyun-inc/fairseq-gec

Dropout causing negative loss values.

PhilippeMarcotte opened this issue.

Hello,

PyTorch implements inverted dropout, which scales the values that are kept by 1/(1 - p) during training. This means that the weights in fairseq's MultiHeadAttention module can exceed 1, since dropout is applied after the softmax.
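The following stdlib-only sketch shows the effect. It assumes a dropout rate p = 0.5 and uses a fixed keep/drop mask in place of PyTorch's random one. torch.nn.functional.dropout scales the kept entries in training mode in the same way.

```python
import math

def softmax(scores):
    # Numerically stable softmax: outputs are non-negative and sum to 1.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def inverted_dropout(weights, mask, p):
    # Zero the dropped entries and scale the kept ones by 1/(1 - p).
    # The fixed 0/1 mask stands in for the random Bernoulli mask used in training.
    return [w * keep / (1.0 - p) for w, keep in zip(weights, mask)]

attn = softmax([2.0, 1.0, 0.1])  # a valid distribution over 3 source positions
dropped = inverted_dropout(attn, [1, 0, 1], p=0.5)
print(max(attn) <= 1.0)     # True
print(max(dropped) > 1.0)   # True: the largest weight is now about 1.32
```

Only the surviving entries are scaled. A single weight can therefore exceed 1, and the row no longer sums to 1.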

The model uses the attention weights of the copy module to compute the final probability distribution p_t(w), so values over 1 are possible there too. This in turn allows the cross-entropy loss to be negative. Cross entropy expects a probability distribution, which is not what it receives during training of this model.
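The following sketch makes the consequence concrete. It assumes the copy-augmented mixture p_t(w) = (1 - α_t^copy) · p_t^gen(w) + α_t^copy · p_t^copy(w). Here p_t^copy(w) is the (dropped-out) attention mass on the source positions holding w. The numbers are illustrative and were not taken from the model.

```python
import math

def mixed_prob(p_gen_w, p_copy_w, copy_alpha):
    # Assumed copy-augmented probability of one target word w:
    # p_t(w) = (1 - copy_alpha) * p_gen(w) + copy_alpha * p_copy(w)
    return (1.0 - copy_alpha) * p_gen_w + copy_alpha * p_copy_w

def nll(prob):
    # Per-token cross-entropy (negative log-likelihood) of the target word.
    return -math.log(prob)

# Illustrative values: p_copy(w) = 1.318 is what an attention weight of 0.659
# becomes after inverted dropout with p = 0.5.
p_t = mixed_prob(p_gen_w=0.9, p_copy_w=1.318, copy_alpha=0.8)
print(round(p_t, 4))   # 1.2344
print(nll(p_t) < 0)    # True: this token's loss is negative
```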

Is this something you have considered?

I've noticed that no negative loss occurs with your training data, which surprises me. Maybe there is something I'm missing in the way you compute the loss. It seems to be related to the labels built from the train.forward file. However, with my own training data (and your code), it does happen.

Thank you.

I was able to verify that even the data used in the paper sometimes produces negative losses. The total loss per batch is always positive, but I think that's only because it never gets low enough. Some individual words do produce negative losses.

To verify this, I only modified the way the loss is computed in cross_entropy.py:

```python
import torch.nn.functional as F

# Method of the cross-entropy criterion in cross_entropy.py.
def compute_weighted_loss(self, model, net_output, sample, reduce=True):
    lprobs = model.get_normalized_probs(net_output, log_probs=True, sample=sample)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output).view(-1)

    # target_label == 1 marks the tokens counted in pos_loss; all others go to neg_loss.
    # Each copy of the targets masks out the other group with padding_idx.
    target_label = sample['target_label'].view(-1).bool()
    neg_target = target.clone().masked_fill_(target_label, self.padding_idx)
    pos_target = target.clone().masked_fill_(~target_label, self.padding_idx)

    # reduction='none' keeps one loss per token (0 where the target is padding_idx),
    # so individual negative values can be counted.
    neg_loss = F.nll_loss(lprobs, neg_target, ignore_index=self.padding_idx, reduction='none')
    pos_loss = F.nll_loss(lprobs, pos_target, ignore_index=self.padding_idx, reduction='none')
    # loss = neg_loss + self.args.positive_label_weight * pos_loss
    if (neg_loss < 0).any():
        print("neg_loss:\n", "  value:", neg_loss.sum().item(), "\n  # negative losses:", (neg_loss < 0).sum().item())
        print("pos_loss:\n", "  value:", pos_loss.sum().item(), "\n  # negative losses:", (pos_loss < 0).sum().item())
    neg_loss = neg_loss.sum()
    pos_loss = pos_loss.sum()
    loss = (1 / self.args.positive_label_weight) * neg_loss + pos_loss

    return loss, loss
```

Here are some results:

```
neg_loss:
   value: 162.010986328125
  # negative losses: 185
pos_loss:
   value: 565.7987060546875
  # negative losses: 0
neg_loss:
   value: 217.8792266845703
  # negative losses: 101
pos_loss:
   value: 1025.3006591796875
  # negative losses: 0
neg_loss:
   value: 328.1896057128906
  # negative losses: 151
pos_loss:
   value: 985.190673828125
  # negative losses: 0
neg_loss:
   value: 206.05780029296875
  # negative losses: 117
pos_loss:
   value: 472.4427490234375
  # negative losses: 0
neg_loss:
   value: 1004.4150390625
  # negative losses: 783
pos_loss:
   value: 1952.7119140625
  # negative losses: 1
```

"# negative losses" is the number of words in the batch whose individual neg_loss (or pos_loss) term is negative.
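A small stand-in for F.nll_loss(..., ignore_index=..., reduction='none') shows how a positive batch total can hide negative per-word terms. It uses only the standard library. The log-probabilities are made up for illustration, and PAD is a hypothetical padding index.

```python
import math

PAD = 0  # hypothetical padding index

def per_token_nll(lprobs, targets, ignore_index=PAD):
    # Mimics F.nll_loss(..., reduction='none'): -log p(target) per position,
    # and 0 at positions whose target equals ignore_index.
    return [0.0 if t == ignore_index else -row[t] for row, t in zip(lprobs, targets)]

lprobs = [
    # Unnormalized row: word 1 gets "probability" 1.2, as a dropped-out mixture can give.
    [math.log(0.01), math.log(1.2), math.log(0.05), math.log(0.05)],
    [math.log(0.01), math.log(0.3), math.log(0.3), math.log(0.39)],
    [math.log(0.25), math.log(0.25), math.log(0.25), math.log(0.25)],
]
targets = [1, 3, PAD]
losses = per_token_nll(lprobs, targets)
print(sum(1 for l in losses if l < 0))  # 1 negative term
print(sum(losses) > 0)                  # True: the batch total is still positive
```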

Very good catch, and thank you for pointing it out in such detail.
I am testing the performance after adding the following line of code:

```python
attn_weights = torch.clamp(attn_weights, 0, 1)
```
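The following stdlib-only sketch shows what that clamp does to one example row of dropped-out attention weights. It mirrors torch.clamp(attn_weights, 0, 1) elementwise. The row values are illustrative: a softmax over three scores followed by inverted dropout with p = 0.5.

```python
def clamp(values, lo=0.0, hi=1.0):
    # Elementwise clamp, like torch.clamp(t, lo, hi); no renormalization is done.
    return [min(max(v, lo), hi) for v in values]

dropped = [1.318, 0.0, 0.197]  # example attention row after inverted dropout
clamped = clamp(dropped)
print(clamped)              # [1.0, 0.0, 0.197]
print(sum(clamped) > 1.0)   # True: each entry is <= 1, but the row sums to 1.197
```

Clamping caps each weight individually. It does not turn the row back into a distribution.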

I've tested using the attn_weights from before dropout. copy_alpha is still computed from the dropped-out attention weights, but p_t^copy corresponds to the attention weights before dropout.
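The following stdlib-only sketch illustrates that variant under several hypothetical assumptions:

- The gate that turns the attention context into copy_alpha is a stand-in sigmoid(gate · context), not the repository's actual layer.
- The dropout mask is fixed.
- The copy distribution is a scatter-add of attention weights onto source token ids.

The sketch shows only the split. copy_alpha sees the dropped-out weights, while p_t^copy is built from the pre-dropout weights and so remains a valid distribution.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def inverted_dropout(weights, mask, p):
    # Fixed 0/1 mask in place of a random one; kept entries are scaled by 1/(1 - p).
    return [w * keep / (1.0 - p) for w, keep in zip(weights, mask)]

def copy_step(scores, values, src_tokens, vocab_size, mask, p, gate):
    attn = softmax(scores)  # before dropout: sums to 1
    attn_dropped = inverted_dropout(attn, mask, p)

    # copy_alpha still sees the dropped-out weights through the attention context.
    # Hypothetical gate: sigmoid(gate . context), not fairseq-gec's actual code.
    context = [sum(a * v[d] for a, v in zip(attn_dropped, values))
               for d in range(len(values[0]))]
    copy_alpha = 1.0 / (1.0 + math.exp(-sum(g * c for g, c in zip(gate, context))))

    # p_copy is scattered from the pre-dropout weights, so it stays a distribution.
    p_copy = [0.0] * vocab_size
    for a, tok in zip(attn, src_tokens):
        p_copy[tok] += a
    return copy_alpha, p_copy

copy_alpha, p_copy = copy_step(
    scores=[2.0, 1.0, 0.1], values=[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],
    src_tokens=[4, 2, 4], vocab_size=6, mask=[1, 0, 1], p=0.5, gate=[0.3, -0.2],
)
print(abs(sum(p_copy) - 1.0) < 1e-9, max(p_copy) <= 1.0)  # True True
```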

On the paper's data, copy_alpha seems to stay similar, although I did not test it more thoroughly than that.

On my own data, however, the model's ability to generalize seems to suffer, probably because dropout now has less effect.

Yes, in our experiments, dropout over the copy attention improves performance.

Bug fixed.