DiceLoss uses 'sum' reduction, but CrossEntropyLoss uses 'mean' reduction
jwc-rad opened this issue · 0 comments
jwc-rad commented
Thank you for the awesome repository!
I've noticed that torch.nn.CrossEntropyLoss
is used for the cross entropy loss, while a custom loss from utils.losses
is used for the Dice loss, as follows:
SSL4MIS/code/train_uncertainty_aware_mean_teacher_3D.py
Lines 124 to 125 in 30e05d8
The Dice loss seems to use a 'sum' reduction as follows:
Lines 169 to 177 in 30e05d8
However, the default reduction method for
torch.nn.CrossEntropyLoss
is 'mean', so the Dice loss is always roughly H*W(*D)
times larger than the CE loss. As a result, a direct mean of the two losses, as used in the following code, is not actually the intended average.
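A toy illustration of the scale mismatch (plain Python, not the actual SSL4MIS losses; the patch size and per-voxel values are hypothetical): a 'sum'-reduced loss over N voxels is exactly N times a 'mean'-reduced loss over the same values, so for a volumetric patch the factor is H*W*D.

```python
# Hypothetical tiny 4x4x4 patch, so H*W*D = 64 voxels.
n_voxels = 4 * 4 * 4
per_voxel_losses = [0.5] * n_voxels  # identical toy per-voxel loss values

# CE-style 'mean' reduction vs. a 'sum'-style reduction over the same values.
mean_reduced = sum(per_voxel_losses) / len(per_voxel_losses)
sum_reduced = sum(per_voxel_losses)

print(mean_reduced)                 # 0.5
print(sum_reduced)                  # 32.0
print(sum_reduced / mean_reduced)   # 64.0, i.e. H*W*D
```

Averaging the two directly, 0.5 * (mean_reduced + sum_reduced), is therefore dominated almost entirely by the 'sum'-reduced term.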
I am sure this has minimal effect on most of your SSL methods, because it effectively amounts to using Dice alone instead of Dice + CE for the supervised loss, but I still think it should be checked.