mlyg/unified-focal-loss

Can we use unified loss function for multiclass segmentation?

Closed this issue · 4 comments

mlyg commented

Absolutely! The easiest way is to modify the code where the scores are calculated for the individual classes. To do this, you need a one-hot encoding of the classes, and you need to know which class corresponds to which axis. For example, there are three classes in the KiTS19 dataset example (background, kidney and kidney tumour, corresponding to axes 0, 1 and 2 respectively), where the kidney and kidney tumour are much smaller than the background, and so the asymmetric Focal and asymmetric Focal Tversky losses can be modified to:
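For concreteness, here is a minimal NumPy sketch of the one-hot encoding and channel convention described above. The label values and the toy volume are hypothetical, not actual KiTS19 data:

```python
import numpy as np

# Toy label map standing in for a segmentation volume:
# 0 = background, 1 = kidney, 2 = kidney tumour
labels = np.array([[[0, 0, 1],
                    [0, 2, 1],
                    [0, 0, 0]]])  # shape (batch=1, H=3, W=3)

num_classes = 3
# One-hot encode along a new trailing class axis -> shape (1, 3, 3, 3)
one_hot = np.eye(num_classes)[labels]

# Channel 0 is background, channel 1 kidney, channel 2 tumour.
# The class imbalance the asymmetric losses are designed for is visible
# directly in the per-channel voxel counts:
print(one_hot[..., 0].sum())  # 6.0 background voxels
print(one_hot[..., 1].sum())  # 2.0 kidney voxels
print(one_hot[..., 2].sum())  # 1.0 tumour voxel
```

The losses below assume exactly this layout: the class axis last, with background in channel 0 and the foreground classes after it.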

import tensorflow as tf
from tensorflow.keras import backend as K
# identify_axis is the axis helper defined elsewhere in this repository

################################
#     Asymmetric Focal loss    #
################################
def asymmetric_focal_loss(delta=0.7, gamma=2.):
    """For imbalanced datasets
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positives and false negatives, by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 2.0
    """
    def loss_function(y_true, y_pred):
        axis = identify_axis(y_true.get_shape())

        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        cross_entropy = -y_true * K.log(y_pred)

        # calculate losses separately for each class, only suppressing background class
        # modify this section below for multiclass segmentation - note the additional
        # axis added to account for 3D segmentation
        back_ce = K.pow(1 - y_pred[:,:,:,:,0], gamma) * cross_entropy[:,:,:,:,0]
        back_ce = (1 - delta) * back_ce

        kidney_ce = cross_entropy[:,:,:,:,1]
        kidney_ce = delta * kidney_ce

        tumour_ce = cross_entropy[:,:,:,:,2]
        tumour_ce = delta * tumour_ce

        loss = K.mean(K.sum(tf.stack([back_ce, kidney_ce, tumour_ce], axis=-1), axis=-1))

        return loss

    return loss_function
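As a quick sanity check (not part of the repository), the same per-class logic can be reproduced in plain NumPy, generalised to any number of foreground classes. The function name and toy data here are hypothetical; channel 0 is assumed to be background, as above:

```python
import numpy as np

def asymmetric_focal_np(y_true, y_pred, delta=0.7, gamma=2.0, epsilon=1e-7):
    """NumPy sketch of the asymmetric Focal loss above.

    y_true, y_pred have shape (batch, ..., num_classes); channel 0 is
    assumed to be the background class, all remaining channels foreground.
    """
    y_pred = np.clip(y_pred, epsilon, 1. - epsilon)
    cross_entropy = -y_true * np.log(y_pred)

    # suppress easy background voxels; leave foreground cross-entropy intact
    back_ce = (1 - delta) * np.power(1 - y_pred[..., 0], gamma) * cross_entropy[..., 0]
    fore_ce = delta * cross_entropy[..., 1:]

    stacked = np.concatenate([back_ce[..., None], fore_ce], axis=-1)
    return np.mean(np.sum(stacked, axis=-1))

# Toy one-hot volume: channel 0 background, 1 kidney, 2 tumour (hypothetical)
y_true = np.eye(3)[np.array([[[0, 0, 1], [0, 2, 1], [0, 0, 0]]])]
confident = np.where(y_true == 1, 0.9, 0.05)   # good soft prediction
uniform = np.full_like(y_true, 1.0 / 3)        # uninformative prediction
print(asymmetric_focal_np(y_true, confident))  # small loss
print(asymmetric_focal_np(y_true, uniform))    # larger loss
```

Because the foreground channels are handled with a single slice (`[..., 1:]`), this form avoids hard-coding one variable per class when extending beyond three classes.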

#################################
# Asymmetric Focal Tversky loss #
#################################
def asymmetric_focal_tversky_loss(delta=0.7, gamma=0.75):
    """This is the implementation for binary segmentation.
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    """
    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)     
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1-y_pred), axis=axis)
        fp = K.sum((1-y_true) * y_pred, axis=axis)
        dice_class = (tp + epsilon)/(tp + delta*fn + (1-delta)*fp + epsilon)

        # calculate losses separately for each class, only enhancing foreground class
        # modify this section below for multiclass segmentation
        back_dice = (1-dice_class[:,0]) 
        kidney_dice = (1-dice_class[:,1]) * K.pow(1-dice_class[:,1], -gamma) 
        tumour_dice = (1-dice_class[:,2]) * K.pow(1-dice_class[:,2], -gamma)

        # Average class scores
        loss = K.mean(tf.stack([back_dice, kidney_dice, tumour_dice],axis=-1))
        return loss

    return loss_function
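The Tversky branch can be sanity-checked the same way in NumPy. Note that `(1 - d) * (1 - d)^(-gamma)` equals `(1 - d)^(1 - gamma)`, which is undefined when the class Dice score is exactly 1, so soft (non-one-hot) predictions are used in the check below. The function name and toy data are hypothetical:

```python
import numpy as np

def asymmetric_focal_tversky_np(y_true, y_pred, delta=0.7, gamma=0.75, epsilon=1e-7):
    """NumPy sketch of the asymmetric Focal Tversky loss above.

    y_true, y_pred have shape (batch, ..., num_classes); channel 0 is
    background, remaining channels foreground.
    """
    # sum over the spatial axes, keeping the batch and class axes
    spatial_axes = tuple(range(1, y_true.ndim - 1))
    tp = np.sum(y_true * y_pred, axis=spatial_axes)
    fn = np.sum(y_true * (1 - y_pred), axis=spatial_axes)
    fp = np.sum((1 - y_true) * y_pred, axis=spatial_axes)
    dice_class = (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)

    # background kept plain; foreground classes focally enhanced
    back_dice = 1 - dice_class[:, 0]
    fore_dice = (1 - dice_class[:, 1:]) * np.power(1 - dice_class[:, 1:], -gamma)

    return np.mean(np.concatenate([back_dice[:, None], fore_dice], axis=-1))

# Toy one-hot volume: channel 0 background, 1 kidney, 2 tumour (hypothetical)
y_true = np.eye(3)[np.array([[[0, 0, 1], [0, 2, 1], [0, 0, 0]]])]
confident = np.where(y_true == 1, 0.9, 0.05)
uniform = np.full_like(y_true, 1.0 / 3)
print(asymmetric_focal_tversky_np(y_true, confident))  # small loss
print(asymmetric_focal_tversky_np(y_true, uniform))    # larger loss
```

In the repository, these two losses are combined as a weighted sum to give the asymmetric Unified Focal loss; the same multiclass modifications carry through to that combination unchanged.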

Hi @mlyg. Could you update your code to add this multiclass version? It has been some time since your paper and I guess lots of people would benefit from a multiclass version, but I can't find one anywhere besides this issue. Thanks!

mlyg commented

@pedrohbd Thanks for bringing this up - it makes a lot of sense to add the multiclass versions to the main code. I will look to get this done as soon as I can.

Best wishes,
Michael