HiLab-git/SSL4MIS

DiceLoss uses 'sum' reduction, but CrossEntropyLoss uses 'mean' reduction

jwc-rad opened this issue · 0 comments

Thank you for the awesome repository!

I've noticed that torch.nn.CrossEntropyLoss is used for the cross-entropy loss, while a custom loss from utils.losses is used for the Dice loss:

```python
ce_loss = CrossEntropyLoss()
dice_loss = losses.DiceLoss(2)
```

The Dice loss seems to use a 'sum' reduction as follows:

```python
def _dice_loss(self, score, target):
    target = target.float()
    smooth = 1e-5
    # All sums run over the whole batch and all spatial dimensions at once.
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    loss = 1 - loss
    return loss
```
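
To make the reduction difference concrete, here is a minimal sketch contrasting the 'mean' and 'sum' reductions of CrossEntropyLoss (the tensor shapes are my own illustrative assumptions):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy 2-class segmentation batch; shapes are illustrative assumptions.
logits = torch.randn(4, 2, 64, 64)         # (N, C, H, W) raw scores
target = torch.randint(0, 2, (4, 64, 64))  # (N, H, W) integer labels

ce_mean = CrossEntropyLoss(reduction='mean')(logits, target)
ce_sum = CrossEntropyLoss(reduction='sum')(logits, target)

# 'sum' is exactly N*H*W times 'mean': 4 * 64 * 64 = 16384 here.
print(ce_sum / ce_mean)  # ~tensor(16384.)
```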

However, the default reduction for torch.nn.CrossEntropyLoss is 'mean', so the Dice loss is roughly H*W(*D) times bigger than the CE loss.
A direct mean of the two losses, as in the following code, is therefore not really the intended average:
```python
supervised_loss = 0.5 * (loss_dice + loss_ce)
```

I am sure this has minimal effect on most of your SSL methods, since it effectively amounts to using Dice alone instead of Dice + CE for the supervised loss, but I still think it should be checked.
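
For instance, one quick way to check would be to log the raw values of both terms during training. A self-contained sketch (the variable names outputs, outputs_soft, and label_batch, and the toy shapes, are my assumptions rather than the exact training-script code):

```python
import torch
from torch.nn import CrossEntropyLoss
from utils import losses  # the repo's custom losses module

ce_loss = CrossEntropyLoss()
dice_loss = losses.DiceLoss(2)

# Toy batch standing in for the real training data (shapes are assumptions).
outputs = torch.randn(4, 2, 64, 64)             # raw logits, (N, C, H, W)
outputs_soft = torch.softmax(outputs, dim=1)    # probabilities for the Dice loss
label_batch = torch.randint(0, 2, (4, 64, 64))  # integer labels, (N, H, W)

loss_ce = ce_loss(outputs, label_batch.long())
loss_dice = dice_loss(outputs_soft, label_batch.unsqueeze(1))
supervised_loss = 0.5 * (loss_dice + loss_ce)

# Comparing the two raw magnitudes shows how much each term actually
# contributes to the combined supervised loss.
print(f"ce: {loss_ce.item():.4f}, dice: {loss_dice.item():.4f}")
```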