BloodAxe/pytorch-toolbelt

Dice Loss/Score question

JanSobus opened this issue · 5 comments

Hey Eugene,

First of all, thank you for this very useful package. I'm transferring my environment from TF to PyTorch now, and having your advanced losses is very helpful. However, when I trained the same model on the same data using the same loss functions in both frameworks, I noticed that I get very different loss numbers (I'm using a multilabel approach). Digging a little deeper into your code, I noticed that when you calculate the Dice Loss you always calculate the per-sample AND per-channel loss and then average it. I don't understand why you do the per-channel calculation and averaging rather than computing the Dice loss for all classes together. I can show what I mean on a dummy example below:

Let's prepare 2 dummy multilabel matrices - ground truth (d_gt) and prediction (d_pr) with 3 classes each, 0 Red, 1 Green and 2 Blue:
import numpy as np
import matplotlib.pyplot as plt

d_gt = np.zeros(shape=(20, 20, 3))
d_gt[5:10, 5:10, 0] = 1
d_gt[10:15, 10:15, 1] = 1
d_gt[:, :, 2] = (1 - d_gt.sum(axis=-1, keepdims=True)).squeeze()
plt.imshow(d_gt)

[image: d_gt, the ground-truth mask]

d_pr = np.zeros(shape=(20, 20, 3))
d_pr[4:9, 4:9, 0] = 1
d_pr[11:14, 11:14, 1] = 1
d_pr[:, :, 2] = (1 - d_pr.sum(axis=-1, keepdims=True)).squeeze()
plt.imshow(d_pr)

[image: d_pr, the predicted mask]

One can see that (using Dice Loss = 1 - Dice Score):

  • Dice Loss for Red is 1 - ((16 + 16) / (25 + 25)) = 0.36
  • Dice Loss for Green is 1 - ((9 + 9) / (9 + 25)) = 0.4706
  • Dice Loss for Blue is 1 - ((341 + 341) / (350 + 366)) = 0.0475

However, the total Dice Loss for the whole picture is 1 - (2 * (16 + 9 + 341) / (2 * 400)) = 0.085

After wrapping them into tensors:
d_gt_tensor = torch.from_numpy(np.transpose(d_gt, (2, 0, 1))).unsqueeze(0)
d_pr_tensor = torch.from_numpy(np.transpose(d_pr, (2, 0, 1))).unsqueeze(0)
Your Dice Loss (with from_logits=False) returns 0.2927, which is the average of the individual channel losses instead of the total loss. The culprit seems to be passing dims=(0,2) to the soft_dice_score function; I think dims=(1,2) should be passed instead to get individual scores for each item in the batch. Unless this behaviour is intended, but then I'd need some more explanation of why.
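To make the discrepancy concrete, here is a minimal numpy sketch (reusing d_gt and d_pr from above; soft_dice is a simplified stand-in written for illustration, not the library's actual soft_dice_score):

import numpy as np

def soft_dice(y_true, y_pred, axis=None, eps=1e-7):
    # Soft Dice score: 2*|A ∩ B| / (|A| + |B|), reduced over `axis`.
    inter = (y_true * y_pred).sum(axis=axis)
    card = y_true.sum(axis=axis) + y_pred.sum(axis=axis)
    return (2.0 * inter + eps) / (card + eps)

# Shape (N, C, H*W), mirroring the library's reshape step.
gt = np.transpose(d_gt, (2, 0, 1)).reshape(1, 3, -1)
pr = np.transpose(d_pr, (2, 0, 1)).reshape(1, 3, -1)

per_channel = 1.0 - soft_dice(gt, pr, axis=(0, 2))  # one loss per class
print(per_channel)               # ~[0.36, 0.4706, 0.0475]
print(per_channel.mean())        # ~0.2927, what DiceLoss returns
print(1.0 - soft_dice(gt, pr))   # global reduction over everything: ~0.085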

Second, smaller question regarding your Dice Loss: why do you use from_logits=True by default?

Thanks in advance!

Greetings! Thanks for your question and the positive feedback on my lib. Let me break your question into parts, so it should be easier to answer.

I noticed that when you calculate the Dice Loss you always calculate the per-sample AND per-channel loss and then average it. I don't understand why you do the per-channel calculation and averaging rather than computing the Dice loss for all classes together.

As I recall, the main driving force for this implementation was to ensure we have as many classes as possible present during loss computation. Let me give an example: suppose we have a batch of 4 images with classes {A,B} in the first image, {B,C} in the second, {C,D} in the third and {A,D} in the fourth. The current implementation computes the soft Dice (or IoU) score for each class within the whole batch, so all classes {A,B,C,D} will have a non-zero support value. If we instead computed the loss for each image individually, classes that are absent from a given sample would contribute nothing for that sample. I believe both implementations should converge to the same accuracy, yet I don't have publications to prove it.
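A small sketch of the support argument (hypothetical shapes and values, chosen only to illustrate the point):

import torch

# Batch of 2 images, 2 classes (A = channel 0, B = channel 1).
# Class A appears only in image 1, class B only in image 0.
y_true = torch.zeros(2, 2, 4, 4)  # (N, C, H, W)
y_true[0, 1, :2, :2] = 1          # image 0 contains only class B
y_true[1, 0, :2, :2] = 1          # image 1 contains only class A

print(y_true.sum(dim=(2, 3)))     # per-image support: [[0., 4.], [4., 0.]]
print(y_true.sum(dim=(0, 2, 3)))  # batch-pooled support: [4., 4.]

# Per image, each class has zero support in one of the samples; pooled
# over the whole batch, every class has non-zero support.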

However, the total Dice Loss for the whole picture is 1 - (2 * (16 + 9 + 341) / (2 * 400)) = 0.085

For the binary and multi-label cases I think the current implementation is just the right option. For multi-class segmentation it's also known to work well, although I can see cases where the proposed loss may work better. Can you please refer to any publication / post which describes why one would want to compute it differently?

Second, smaller question regarding your Dice Loss: why do you use from_logits=True by default?

That's due to the fact that we don't want to have a final activation layer in our models by default. It's recommended to use the BCEWithLogitsLoss / CrossEntropyLoss losses, which compute log_sigmoid/log_softmax with greater numerical accuracy than a plain x.sigmoid().log(). Since this is a well-known convention in PyTorch, DiceLoss defaults to from_logits=True. In that case it computes the probabilities itself using the log_sigmoid().exp() trick.
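A quick demonstration of why the logits convention matters numerically (values picked to force float32 underflow):

import torch
import torch.nn.functional as F

logits = torch.tensor([-200.0, 0.0, 200.0])

# Naive route: sigmoid underflows to 0 for very negative logits,
# so the subsequent log blows up to -inf.
print(torch.sigmoid(logits).log())   # ~[-inf, -0.6931, 0.0]

# Stable route used by BCEWithLogitsLoss and friends.
print(F.logsigmoid(logits))          # ~[-200.0, -0.6931, 0.0]

# With from_logits=True, probabilities are recovered from the stable
# log-probabilities: the log_sigmoid().exp() trick mentioned above.
print(F.logsigmoid(logits).exp())    # ~[0.0, 0.5, 1.0]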

Thanks for the quick reply!
I always understood that for area-based losses (Dice, Jaccard) the loss is 1 - the corresponding score. I was using the segmentation_models package (https://github.com/qubvel/segmentation_models) in TF and, after moving to PyTorch, switched to the PyTorch version which utilises your toolbelt (https://github.com/qubvel/segmentation_models.pytorch). As I mentioned in the opening post, what got me thinking was that during training, similar metrics resulted in vastly different loss scores (using the same metrics and loss functions in both frameworks). Even when checking the loss scores and metrics on single gt_mask/pred_mask pairs, the metrics in both frameworks were the same (same Dice Score) but the loss values were different. The TF version was showing loss = 1 - score as expected and PyTorch was not, which led to the dummy tests above.

Now, I understand the premise of the per-class calculation, but wouldn't it be sufficient to calculate the score for the whole batch then (dims=None)? That results in loss = 1 - score too. In all honesty, none of the previous implementations of Dice Loss I stumbled upon made a class distinction, for example https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch and https://lars76.github.io/2018/09/27/loss-functions-for-segmentation.html, which is why your implementation really took me by surprise. And the underlying reason for all this digging is that I haven't yet been able to reproduce my TF results in PyTorch using the same model (Unet with EfficientNet B0-B3 backbone) on my highly imbalanced dataset (2 important classes covering a very small part of the image, plus a background class, with around 50% of samples containing only background).

Other than that, I noticed 2 more things in your Dice Loss code that raise questions:

  1. Is channel ignoring going to work in the MULTILABEL case? [108-111] y_true is already only 0s and 1s of shape [N,C,H,W], so if we try to ignore a class/channel index > 1 the mask is going to pass everything through (see the sketch after this list).
  2. In the comment on line [120], shouldn't it say "Dice loss is undefined for empty classes"?
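A minimal sketch of point 1 (assuming the mask is built by comparing y_true values against ignore_index, as the linked lines suggest):

import torch

# Multilabel targets only ever contain 0s and 1s.
y_true = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

ignore_index = 2               # trying to ignore channel/class 2
mask = y_true != ignore_index  # compares values, not channel indices
print(mask)                    # all True, so nothing is actually ignored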

Thanks again!

I am still struggling with this same issue; could you help me understand it better?

Running this piece of code makes the Dice loss return 0, even though none of the predictions were correct. I would have expected a loss of 1 instead. How can that be?

import torch
from segmentation_models_pytorch.losses import BINARY_MODE, DiceLoss

dice_loss = DiceLoss(
    mode=BINARY_MODE,
    from_logits=False,  # If True, applies log_sigmoid or log_softmax to y_pred.
)

batch_size = 1
num_classes = 1
image_size = 10
zeros = torch.zeros(size=(batch_size, num_classes, image_size, image_size), dtype=torch.float64)
ones = torch.ones(size=(batch_size, num_classes, image_size, image_size), dtype=torch.float64)
print(dice_loss(ones, zeros))  # Prints 0

Adding an epsilon does seem to solve the error though:

zeros = torch.zeros(size=(batch_size, num_classes, image_size, image_size), dtype=torch.float64)
zeros += 1e-90  # Adding an epsilon does solve the error
ones = torch.ones(size=(batch_size, num_classes, image_size, image_size), dtype=torch.float64)
print(dice_loss(ones, zeros))  # Prints 1

The signature of all losses is defined as forward(predictions, targets), so the second argument holds the ground-truth values.
Secondly, the Dice metric is not defined when there are no positive targets (as when you pass an all-zeros tensor). To avoid NaN in the loss, it falls back to zero.
Hope this clarifies why you are getting zero output in the first case.
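A rough sketch of that fallback (my own simplification of the idea, not the library's exact code):

import torch

eps = 1e-7
y_pred = torch.ones(1, 1, 10, 10).reshape(1, 1, -1)
y_true = torch.zeros(1, 1, 10, 10).reshape(1, 1, -1)

# Soft Dice loss per class, pooled over batch and spatial dims.
inter = (y_pred * y_true).sum(dim=(0, 2))
card = y_pred.sum(dim=(0, 2)) + y_true.sum(dim=(0, 2))
loss = 1.0 - (2 * inter + eps) / (card + eps)   # ~1.0 here

# Classes with zero positive targets are masked out of the final loss,
# which is why an all-zeros target yields 0 rather than 1 (or NaN).
mask = y_true.sum(dim=(0, 2)) > 0
print(loss * mask.float())   # tensor([0.])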

Thanks for your reply.
It is clear to me now that you would prefer a zero loss instead of a NaN loss.

However, why would you say that the Dice loss is not defined when there are no positive targets?
Looking up the Dice metric on Wikipedia suggests it is just a measure of overlap between two sets. Do you have a source for this which I could read to understand it better?
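For what it's worth, the usual definition is DSC(A, B) = 2|A ∩ B| / (|A| + |B|). When both A and B are empty, this evaluates to 0/0, which is presumably the undefined case the zero fallback guards against.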