usuyama/pytorch-unet

Can dice and bce loss work on a multi-class task?

Dr-Cube opened this issue · 5 comments

Thanks for the great implementation code.

I am confused about the loss function. As far as I can see, dice and BCE are both used for binary classification. Can they work well on a multi-class task? From your code I can see the losses work OK, but what about a bigger dataset?

I tried F.cross_entropy(), but it gives me this: RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [36, 4, 224, 224]. Could you please tell me what's wrong? Thanks!

import torch
import torch.nn.functional as F

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # This is the line that raises the RuntimeError: F.cross_entropy expects
    # class-index targets of shape [N, H, W], not the one-hot [N, C, H, W]
    # target used for BCE, and casting to LongTensor does not change the shape.
    target_long = target.type(torch.LongTensor)
    ce = F.cross_entropy(pred, target_long.cuda())

    # pred = F.sigmoid(pred)  # deprecated; torch.sigmoid is the replacement
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)  # dice_loss is defined elsewhere in the repo

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss
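
For context, F.cross_entropy with a spatial input of shape [N, C, H, W] expects a target of class indices with shape [N, H, W]; the 4D one-hot target of size [36, 4, 224, 224] above is exactly what triggers the "only batches of spatial targets supported (non-empty 3D tensors)" error. A minimal shape sketch (the sizes are taken from the error message, everything else is illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(36, 4, 224, 224)                  # [N, C, H, W] raw network output
target_indices = torch.randint(0, 4, (36, 224, 224))   # [N, H, W] class index per pixel

loss = F.cross_entropy(logits, target_indices)          # works: 3D spatial target
# passing the 4D one-hot target instead is what raised the RuntimeError above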

You're right. The current code uses BCE, so each pixel can have multiple classes, i.e. multi-label.

To make it a single class per pixel, i.e. multi-class, you can use CE. I think you need to use reshape/view to 2D.
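
A minimal sketch of what that could look like, assuming the one-hot [N, C, H, W] target is collapsed to class indices with argmax rather than reshaped, and reusing the repo's dice_loss helper on softmax probabilities (the name calc_multiclass_loss and the ce_weight parameter are illustrative, not part of the repo):

import torch
import torch.nn.functional as F

def calc_multiclass_loss(pred, target_onehot, metrics, ce_weight=0.5):
    # pred:          [N, C, H, W] raw logits from the network
    # target_onehot: [N, C, H, W] one-hot encoded target
    # F.cross_entropy wants class indices of shape [N, H, W],
    # so collapse the channel dimension with argmax.
    target_indices = torch.argmax(target_onehot, dim=1)
    ce = F.cross_entropy(pred, target_indices)

    # For the dice term, use softmax (one class per pixel) instead of sigmoid.
    pred_soft = F.softmax(pred, dim=1)
    dice = dice_loss(pred_soft, target_onehot)  # repo's dice_loss helper

    loss = ce * ce_weight + dice * (1 - ce_weight)

    metrics['ce'] += ce.data.cpu().numpy() * target_onehot.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target_onehot.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target_onehot.size(0)

    return loss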

Why are the metrics multiplied by the batch size, added cumulatively and then divided by the total number of samples during printing?

It seems like this will print a scaled version of the average metric value (from print_metrics), depending on the batch size. Please correct me if I'm wrong.

Hi, I am having a problem with a multi-class task where the dimensions are like this:

MASK TARGET: torch.Size([4, 1, 600, 900, 3])
OUTPUT: torch.Size([4, 5, 600, 900])
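
One possible way to bridge those shapes, assuming the trailing 3 is an RGB colour coding of the classes and the [4, 5, 600, 900] logits will go into F.cross_entropy: map each colour to a class index so the target becomes [4, 600, 900]. The colour-to-class mapping and the helper name below are purely hypothetical:

import torch

# Hypothetical mapping from RGB colours in the mask to class indices 0..4
color_to_class = {
    (0, 0, 0): 0,
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
    (255, 255, 0): 4,
}

def rgb_mask_to_indices(mask):
    # mask: [N, 1, H, W, 3] RGB-coded target -> [N, H, W] class indices
    mask = mask.squeeze(1).long()                          # [N, H, W, 3]
    indices = torch.zeros(mask.shape[:3], dtype=torch.long)
    for color, cls in color_to_class.items():
        match = (mask == torch.tensor(color)).all(dim=-1)  # [N, H, W]
        indices[match] = cls
    return indices

# target = rgb_mask_to_indices(mask_target)   # [4, 600, 900]
# loss = F.cross_entropy(output, target)      # output: [4, 5, 600, 900]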

@ckolluru Have you created your loss function for multiclass already?

Why are the metrics multiplied by the batch size, added cumulatively and then divided by the total number of samples during printing?

The reason is that some batches (e.g. the last batch) may have fewer training examples than the others. Accumulating metric * batch_size and dividing by the total number of samples gives a better estimate of the metric over a complete epoch. Skimming through the Training a Classifier example from the PyTorch tutorials shows the same strategy in section "4. Train the network", where the statistics are printed.
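
To make the arithmetic concrete, here's a small sketch with made-up numbers, assuming an epoch of 10 samples split into batches of 4, 4 and 2:

# per-batch mean losses and batch sizes (made-up numbers)
batch_losses = [0.9, 0.7, 0.5]
batch_sizes  = [4, 4, 2]

# a naive mean of the batch means weights the last, smaller batch too heavily
naive_mean = sum(batch_losses) / len(batch_losses)             # 0.70

# accumulating loss * batch_size and dividing by the total number of samples
# weights every sample equally, which is what print_metrics reports
total = sum(l * n for l, n in zip(batch_losses, batch_sizes))  # 0.9*4 + 0.7*4 + 0.5*2 = 7.4
epoch_samples = sum(batch_sizes)                               # 10
weighted_mean = total / epoch_samples                          # 0.74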

It seems like this will print a scaled version of the average metric value (from print_metrics), depending on the batch size. Please correct me if I'm wrong.

That's true! You may get a better insight into this topic by reading How is the loss for an epoch reported?

If you want to use cross entropy, make sure you're not applying a sigmoid function beforehand.
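
A minimal illustration of that point (the shapes are chosen arbitrarily): pass the raw logits straight to F.cross_entropy, which applies log-softmax internally.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5, 64, 64)           # raw network output: [N, C, H, W]
target = torch.randint(0, 5, (8, 64, 64))    # class indices: [N, H, W]

loss = F.cross_entropy(logits, target)       # correct: no sigmoid/softmax beforehand
# loss = F.cross_entropy(torch.sigmoid(logits), target)   # squashing first distorts the loss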