HuCaoFighting/Swin-Unet

Dice loss is weird

hdnminh opened this issue · 1 comment

Hi @HuCaoFighting

I hope you have a great last week of the year!

Could you explain your Dice loss implementation?

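    import torch

    def _dice_loss(self, score, target):
        # score: predicted probabilities for one class; target: its ground-truth mask
        target = target.float()
        smooth = 1e-5  # avoids 0/0 when both prediction and target are empty
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)  # sum of squared targets
        z_sum = torch.sum(score * score)    # sum of squared scores
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)  # soft Dice coefficient
        loss = 1 - loss  # turn the coefficient into a loss to minimize
        return loss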

Why are y_sum and z_sum computed as sums of the squared target and score tensors, respectively? The standard Dice formula just sums the tensors, without squaring them.
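For comparison, here are the standard Dice coefficient and the squared-denominator form that the quoted code computes. This is written from the code above, with $p_i$ the predicted score and $g_i$ the target at voxel $i$; the `smooth` term is omitted, and the loss is $1 - D$ in both cases:

```latex
D_{\text{plain}} = \frac{2\sum_i p_i g_i}{\sum_i p_i + \sum_i g_i},
\qquad
D_{\text{squared}} = \frac{2\sum_i p_i g_i}{\sum_i p_i^2 + \sum_i g_i^2}
```

Because the target is binary ($g_i \in \{0, 1\}$, so $g_i^2 = g_i$), `y_sum` is the same under both forms; only the score term differs, and for $p_i \in [0, 1]$ we have $p_i^2 \le p_i$.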

If anyone is interested, see the V-Net paper (Milletari et al., 2016), which uses this squared formulation.
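A minimal pure-Python sketch comparing the two denominators on toy data. It uses no PyTorch so it runs anywhere; the function name `dice_loss` and the `squared` flag are hypothetical illustrations, not part of Swin-Unet:

```python
def dice_loss(score, target, squared=True, smooth=1e-5):
    """Soft Dice loss on flat lists of floats (hypothetical helper).

    squared=True mirrors the quoted _dice_loss (squared sums in the denominator);
    squared=False uses the plain sums of the standard Dice coefficient.
    """
    intersect = sum(s * t for s, t in zip(score, target))
    if squared:
        denom = sum(s * s for s in score) + sum(t * t for t in target)
    else:
        denom = sum(score) + sum(target)
    return 1 - (2 * intersect + smooth) / (denom + smooth)


target = [1.0, 1.0, 0.0, 0.0]
hard = [1.0, 1.0, 0.0, 0.0]  # binary prediction that matches the target
soft = [0.9, 0.6, 0.2, 0.1]  # probabilistic prediction

for name, pred in [("hard", hard), ("soft", soft)]:
    print(name, dice_loss(pred, target, squared=True), dice_loss(pred, target, squared=False))
```

With a binary prediction both variants give a loss of 0. With the soft prediction the squared form gives roughly 0.068 versus roughly 0.211 for the plain form, since squaring shrinks scores below 1 and so shrinks the denominator.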