bubbliiiing/unet-pytorch

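Hi, I'd like to ask about the `Dice_loss` function defined in the unet_training code. What does `temp_target[..., :-1]` mean, i.e. what does that slice select? And what does `axis` mean in the `torch.sum` calls?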
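A minimal sketch of what the two constructs do, using made-up sizes (batch 2, 3 classes, a 4x4 image) and a random one-hot target; these sizes are illustrative assumptions, not values from the repository. In `Dice_loss`, `temp_target[..., :-1]` is multiplied element-wise with `temp_inputs`, whose last dimension is `c`, so the target must carry `ct = c + 1` channels and the slice keeps every channel except the last (`...` stands for all leading dimensions). What that extra channel encodes is not stated in the snippet; an "ignore" class for unlabeled pixels is a common convention, but that is an assumption here. `axis` is accepted by `torch.sum` as an alias of `dim`.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumed): batch n=2, c=3 predicted classes, 4x4 image.
n, c, h, w = 2, 3, 4, 4
ct = c + 1  # target has one extra channel, so ct - 1 == c

# One-hot target in (n, h, w, ct) layout, class indices drawn from 0..ct-1.
labels = torch.randint(0, ct, (n, h, w))
target = F.one_hot(labels, num_classes=ct).float()
temp_target = target.view(n, -1, ct)  # (n, h*w, ct)

# "..." covers all leading dims; ":-1" keeps every channel except the last.
sliced = temp_target[..., :-1]  # (n, h*w, c)
print(sliced.shape)  # torch.Size([2, 16, 3])

# Summing over dims 0 (batch) and 1 (pixels) leaves one total per class.
per_class = torch.sum(sliced, axis=[0, 1])  # shape (c,)
print(per_class.shape)  # torch.Size([3])

# axis is just an alias of dim.
assert torch.equal(per_class, sliced.sum(dim=(0, 1)))
```

Here `per_class[k]` is the number of pixels labelled with class `k`; pixels whose label falls in the dropped last channel are not counted.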

BaronDuan opened this issue · 0 comments

import torch
import torch.nn.functional as F

def Dice_loss(inputs, target, beta=1, smooth=1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    # Resize the predictions if either spatial dimension differs from the target
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n, h*w, c), then softmax over the class dimension
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).contiguous().view(n, -1, c), -1)
    # (n, h, w, ct) -> (n, h*w, ct)
    temp_target = target.view(n, -1, ct)

    # Compute the Dice loss: per-class tp, fp, fn totals over batch and pixels
    tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1])
    fp = torch.sum(temp_inputs, axis=[0, 1]) - tp
    fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss
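The function can be exercised end to end with a self-contained sketch. It follows the quoted code but, as deliberate changes, writes `dim=` instead of the equivalent `axis=`, uses `reshape` instead of `.contiguous().view`, and resizes when either spatial size differs (`or` rather than `and`). The toy tensors and the logit scale of 20 are arbitrary assumptions chosen so that a confident, correct prediction gives a loss near 0 and a uniform prediction gives a clearly larger one.

```python
import torch
import torch.nn.functional as F


def dice_loss(inputs, target, beta=1, smooth=1e-5):
    """Dice (F-beta) loss sketch. inputs: logits (n, c, h, w);
    target: one-hot (n, h, w, c + 1) whose last channel is dropped."""
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # (n, c, h, w) -> (n, h*w, c), probabilities over classes
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).reshape(n, -1, c), dim=-1)
    # (n, h, w, ct) -> (n, h*w, ct), then keep only the first c channels
    temp_target = target.reshape(n, -1, ct)[..., :-1]
    # Per-class totals over batch (dim 0) and pixels (dim 1)
    tp = torch.sum(temp_target * temp_inputs, dim=(0, 1))
    fp = torch.sum(temp_inputs, dim=(0, 1)) - tp
    fn = torch.sum(temp_target, dim=(0, 1)) - tp
    score = ((1 + beta ** 2) * tp + smooth) / (
        (1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth
    )
    return 1 - torch.mean(score)


if __name__ == "__main__":
    n, c, h, w = 2, 3, 4, 4
    # Deterministic labels so that every class appears in every sample
    labels = (torch.arange(h * w) % c).view(1, h, w).repeat(n, 1, 1)
    target = F.one_hot(labels, num_classes=c + 1).float()  # (n, h, w, c + 1)
    # Confident, correct logits vs. a uniform prediction
    perfect = F.one_hot(labels, num_classes=c).permute(0, 3, 1, 2).float() * 20.0
    uniform = torch.zeros(n, c, h, w)
    print(dice_loss(perfect, target).item())  # close to 0
    print(dice_loss(uniform, target).item())  # roughly 0.67
```

Pixels whose one-hot label falls in the dropped last channel contribute nothing to `tp` or `fn`, but their predicted probabilities still count toward `fp`.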