Dice loss is weird
hdnminh opened this issue
hdnminh commented
I hope you have a great last week of the year!
Could you explain your implementation of the Dice loss:
import torch

def _dice_loss(self, score, target):
    target = target.float()
    smooth = 1e-5  # keeps the ratio defined when both tensors are all zeros
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)  # sum of squared targets
    z_sum = torch.sum(score * score)  # sum of squared predictions
    loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    loss = 1 - loss
    return loss
Why are y_sum and z_sum computed as sums of the squared target and score tensors, respectively? According to the standard Dice formula, the denominator should simply sum each tensor, without squaring it.
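A minimal pure-Python sketch (no PyTorch, so it runs anywhere) that compares the two denominators. The function name dice_loss and the squared flag are illustrative assumptions, not part of the repository. For a binary target, target * target == target, so y_sum is unchanged; squaring only affects z_sum when score holds soft probabilities.

```python
def dice_loss(score, target, squared=True, smooth=1e-5):
    """Soft Dice loss on flat lists of floats (hypothetical helper).

    squared=True mirrors the snippet in the question (sum of squares in the denominator).
    squared=False uses plain sums, as in the textbook Dice formula.
    """
    # Overlap term: sum of element-wise products
    intersect = sum(s * t for s, t in zip(score, target))
    if squared:
        denom = sum(s * s for s in score) + sum(t * t for t in target)
    else:
        denom = sum(score) + sum(target)
    return 1 - (2 * intersect + smooth) / (denom + smooth)


score = [0.9, 0.6, 0.2, 0.1]
target = [1, 1, 0, 0]
print(dice_loss(score, target))                 # squared denominator
print(dice_loss(score, target, squared=False))  # plain sums
```

With these inputs the squared version gives a loss of about 0.068 versus about 0.211 for the plain sums: squaring values in (0, 1) shrinks the denominator, so the same overlap yields a higher Dice score. For binary predictions the two versions agree.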
hdnminh commented
If anyone is interested, see the V-Net paper, which defines the Dice coefficient with squared sums in the denominator.
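For reference, the two forms side by side, with p_i the predicted probabilities (score), g_i the ground-truth labels (target), and \epsilon standing for smooth in the snippet. The second line is the V-Net-style loss that the code computes; since g_i \in \{0, 1\} implies g_i^2 = g_i, the only practical difference is the p_i^2 term.

```latex
% Textbook Dice coefficient on sets X, Y
D = \frac{2\,|X \cap Y|}{|X| + |Y|}

% V-Net-style soft Dice loss, as computed by _dice_loss
L_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i^2 + \sum_i g_i^2 + \epsilon}
```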