Can we use the unified loss function for multiclass segmentation?
Closed this issue · 4 comments
xlar-sanjeet commented
Can we use the unified loss function for multiclass segmentation?
mlyg commented
Absolutely! The easiest way is to modify the code where the scores are calculated for the individual classes. To do this, you need a one-hot encoding of the classes, and you need to know which channel (index along the last axis) corresponds to each class. For example, the KiTS19 dataset example has three classes (background, kidney and kidney tumour, corresponding to channels 0, 1 and 2 respectively), where the kidney and kidney tumour are much smaller than the background, so the asymmetric Focal and asymmetric Focal Tversky losses can be modified as follows:
```python
import tensorflow as tf
from tensorflow.keras import backend as K

################################
#    Asymmetric Focal loss     #
################################
def asymmetric_focal_loss(delta=0.7, gamma=2.):
    def loss_function(y_true, y_pred):
        """For imbalanced datasets, modified here for three-class 3D segmentation (KiTS19).
        Parameters
        ----------
        delta : float, optional
            controls weight given to false positives and false negatives, by default 0.7
        gamma : float, optional
            Focal loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
        """
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        cross_entropy = -y_true * K.log(y_pred)

        # Calculate losses separately for each class, only suppressing the background class.
        # Modify this section for multiclass segmentation; note the additional axis added to account for 3D segmentation.
        # Channel 0: background, easy examples down-weighted by the focal term
        back_ce = K.pow(1 - y_pred[:,:,:,:,0], gamma) * cross_entropy[:,:,:,:,0]
        back_ce = (1 - delta) * back_ce
        # Channel 1: kidney, rare class, no focal suppression
        kidney_ce = cross_entropy[:,:,:,:,1]
        kidney_ce = delta * kidney_ce
        # Channel 2: kidney tumour, rare class, no focal suppression
        tumour_ce = cross_entropy[:,:,:,:,2]
        tumour_ce = delta * tumour_ce

        # Sum over classes, then average over batch and voxels
        loss = K.mean(K.sum(tf.stack([back_ce, kidney_ce, tumour_ce], axis=-1), axis=-1))
        return loss
    return loss_function
```
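The slicing above hard-codes one line per class. As a hedged illustration of how the same pattern extends to any number of classes, here is a NumPy sketch (not the repository's code) that assumes channels-last arrays, a one-hot `y_true` and a softmax `y_pred`. `one_hot`, `asymmetric_focal_loss_np` and the `background_classes` parameter (the channels whose easy examples are suppressed, `(0,)` for KiTS19) are hypothetical names introduced here.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Turn an integer label map of shape (...) into a channels-last one-hot array (..., C)."""
    return np.eye(num_classes)[labels]


def asymmetric_focal_loss_np(y_true, y_pred, background_classes=(0,), delta=0.7, gamma=2.0):
    """N-class asymmetric Focal loss (hypothetical NumPy sketch, channels-last, one-hot y_true).

    Channels in `background_classes` get the focal term (1 - p) ** gamma and
    weight (1 - delta); all other channels keep plain cross-entropy with weight delta.
    """
    eps = 1e-7
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    cross_entropy = -y_true * np.log(y_pred)  # shape (..., C)

    per_class = []
    for c in range(y_true.shape[-1]):
        if c in background_classes:
            # Suppress easy background voxels, like back_ce in the Keras code
            loss_c = (1.0 - delta) * np.power(1.0 - y_pred[..., c], gamma) * cross_entropy[..., c]
        else:
            # Rare foreground class: no suppression, like kidney_ce / tumour_ce
            loss_c = delta * cross_entropy[..., c]
        per_class.append(loss_c)

    # Sum over classes, then average over batch and spatial positions
    return float(np.mean(np.sum(np.stack(per_class, axis=-1), axis=-1)))


# Example: a batch of two 4x4x4 volumes with three classes (background, kidney, tumour)
rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(2, 4, 4, 4))
y_true = one_hot(labels, 3)                                           # shape (2, 4, 4, 4, 3)
logits = rng.normal(size=y_true.shape)
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over channels
print(asymmetric_focal_loss_np(y_true, y_pred, background_classes=(0,)))
```

With `background_classes=(0,)` and three channels this reduces to the same sum of `back_ce`, `kidney_ce` and `tumour_ce` as the Keras version; adding a class only requires a wider one-hot encoding rather than a new line of slicing.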
```python
import tensorflow as tf
from tensorflow.keras import backend as K


def identify_axis(shape):
    # Spatial axes for channels-last tensors: 5D (3D images) or 4D (2D images).
    # Assumed to match the identify_axis helper used alongside these losses.
    if len(shape) == 5:
        return [1, 2, 3]
    elif len(shape) == 4:
        return [1, 2]
    raise ValueError('Metric: Shape of tensor is neither 2D or 3D.')

#################################
# Asymmetric Focal Tversky loss #
#################################
def asymmetric_focal_tversky_loss(delta=0.7, gamma=0.75):
    """Modified here for three-class segmentation (KiTS19: background, kidney, kidney tumour).
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positives and false negatives, by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    """
    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1-y_pred), axis=axis)
        fp = K.sum((1-y_true) * y_pred, axis=axis)
        dice_class = (tp + epsilon)/(tp + delta*fn + (1-delta)*fp + epsilon)  # shape (batch, classes)

        # Calculate losses separately for each class, only enhancing the foreground classes.
        # Modify this section for multiclass segmentation.
        back_dice = (1-dice_class[:,0])
        kidney_dice = (1-dice_class[:,1]) * K.pow(1-dice_class[:,1], -gamma)
        tumour_dice = (1-dice_class[:,2]) * K.pow(1-dice_class[:,2], -gamma)

        # Average class scores
        loss = K.mean(tf.stack([back_dice, kidney_dice, tumour_dice], axis=-1))
        return loss
    return loss_function
```
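The same generalisation can be sketched for the asymmetric Focal Tversky loss. This is again a hypothetical NumPy illustration rather than the repository's code: `asymmetric_focal_tversky_loss_np` and `background_classes` are names introduced here, and the foreground term is written as `(1 - TI) ** (1 - gamma)`, which is algebraically the same as `(1 - TI) * (1 - TI) ** -gamma` above but avoids `0 * inf = nan` if the Tversky index reaches exactly 1.

```python
import numpy as np


def asymmetric_focal_tversky_loss_np(y_true, y_pred, background_classes=(0,), delta=0.7, gamma=0.75):
    """N-class asymmetric Focal Tversky loss (hypothetical NumPy sketch, channels-last, one-hot y_true).

    Works for 2D (B, H, W, C) or 3D (B, H, W, D, C) inputs. Background channels
    use 1 - TI; all other channels use (1 - TI) ** (1 - gamma).
    """
    eps = 1e-7
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    spatial_axes = tuple(range(1, y_true.ndim - 1))  # every axis except batch and channel

    tp = np.sum(y_true * y_pred, axis=spatial_axes)
    fn = np.sum(y_true * (1.0 - y_pred), axis=spatial_axes)
    fp = np.sum((1.0 - y_true) * y_pred, axis=spatial_axes)
    tversky = (tp + eps) / (tp + delta * fn + (1.0 - delta) * fp + eps)  # shape (B, C)

    per_class = []
    for c in range(y_true.shape[-1]):
        loss_c = 1.0 - tversky[:, c]
        if c not in background_classes:
            # Enhance rare foreground classes: with gamma < 1, small losses are raised
            loss_c = np.power(loss_c, 1.0 - gamma)
        per_class.append(loss_c)

    # Average over batch and classes
    return float(np.mean(np.stack(per_class, axis=-1)))


# Example: a batch of two 8x8 2D images with three classes
rng = np.random.default_rng(1)
labels = rng.integers(0, 3, size=(2, 8, 8))
y_true = np.eye(3)[labels]                                            # one-hot, shape (2, 8, 8, 3)
logits = rng.normal(size=y_true.shape)
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over channels
print(asymmetric_focal_tversky_loss_np(y_true, y_pred))
```

Since the Unified Focal loss is a weighted sum of these two components, it should not need further changes once both components have been adapted to the multiclass case.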
xlar-sanjeet commented
Thank you
On Thu, 14 Apr 2022 at 15:18, mlyg closed #8.