Bug in categorical_focal_loss?
Closed this issue · 4 comments
I have tested the code of 'categorical_focal_loss' and I think I have found a bug. If I understand correctly, the focal loss should be equal to categorical cross-entropy when gamma=0.0 and alpha=1.0. However, that is not the case: if I compare its output against the Keras cross-entropy, the results match only up to a factor.
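For reference, here is a common per-sample form of the categorical focal loss, for $C$ classes with one-hot target $y$ and predicted probabilities $p$. It reduces to cross-entropy when $\gamma = 0$ and $\alpha = 1$. If the per-class terms are averaged instead of summed (assuming that is what the earlier version did), the result is divided by $C$, which would produce a constant factor like the one described above:

$$
\mathrm{FL}(y, p) = -\sum_{c=1}^{C} \alpha \, (1 - p_c)^{\gamma} \, y_c \log p_c ,
\qquad
\mathrm{FL}(y, p)\big|_{\gamma = 0,\ \alpha = 1} = -\sum_{c=1}^{C} y_c \log p_c = \mathrm{CE}(y, p)
$$

$$
\frac{1}{C} \sum_{c=1}^{C} \left( - y_c \log p_c \right) = \frac{\mathrm{CE}(y, p)}{C}
$$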
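```python
import numpy as np
import tensorflow as tf
# categorical_focal_loss: the function from this repository (import omitted)

a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3, 3])  # one-hot targets
print(a)
b = tf.constant([.9, .05, .05, .5, .89, .6, .05, .01, .94], shape=[3, 3])  # predictions
print(b)
loss = tf.keras.backend.categorical_crossentropy(a, b)
print(np.around(loss, 5))
cfl = categorical_focal_loss(gamma=0., alpha=1.)
loss = cfl(a, b)
print(np.around(loss, 5))
```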
It seems that your function should return the sum of the loss over the class axis instead of the mean, that is:

```python
return K.sum(loss, axis=-1)
```

Am I right?
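Here is a NumPy sketch that checks the factor, assuming the earlier version averaged over the class axis. `keras_style_cce` mirrors what tf.keras.backend.categorical_crossentropy does with probabilities: it rescales each row to sum to 1, clips, and sums over classes. `focal_terms` is a hypothetical stand-in for the element-wise focal terms. The data are `a` and `b` from the post above:

```python
import numpy as np

EPS = 1e-7  # mirrors Keras' default K.epsilon()


def keras_style_cce(y_true, y_pred):
    """Per-sample cross-entropy: rescale rows to sum to 1, clip, sum over classes."""
    p = y_pred / y_pred.sum(axis=-1, keepdims=True)
    p = np.clip(p, EPS, 1.0 - EPS)
    return -np.sum(y_true * np.log(p), axis=-1)


def focal_terms(y_true, y_pred, gamma=0.0, alpha=1.0):
    """Hypothetical element-wise focal terms (no rescaling of y_pred)."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    return alpha * (1.0 - p) ** gamma * (-y_true * np.log(p))


a = np.eye(3)  # same values as the flattened one-hot tensor above
b = np.array([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])

ce = keras_style_cce(a, b)
fl_sum = focal_terms(a, b).sum(axis=-1)    # proposed: sum over the class axis
fl_mean = focal_terms(a, b).mean(axis=-1)  # averaging divides by the 3 classes

print(np.around(ce, 5))
print(np.around(fl_sum, 5))
print(np.around(fl_mean, 5))
```

With these inputs, the summed focal loss matches the cross-entropy for the first and third rows, and the class-axis mean is exactly one third of the sum. The second row of `b` sums to 1.99, so Keras rescales it before taking the log while the focal terms do not. That row therefore differs for a reason unrelated to sum versus mean.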
I think you are right, but I'm not sure.
Hi @DanielPonsa, we replaced the sum with the mean in this PR because there was a bug related to batching.
You are right that with gamma equal to zero and alpha equal to one, the categorical focal loss should be equal to the categorical cross-entropy. However, if you use the sum, you still get odd numbers, because the loss is summed for each batch.
Any suggestions?
Hi @umbertogriffo,
with the suggested code (K.sum(loss, axis=-1)) you obtain one loss value per pixel in the batch, mimicking the behaviour of the Keras function tf.keras.losses.categorical_crossentropy. If you want to mimic the behaviour of the class tf.keras.losses.CategoricalCrossentropy (which performs reduction by default), you have to compute the mean value in an additional step, or return K.mean(K.sum(loss, axis=-1)) from your function.
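The NumPy sketch below contrasts the two options. It assumes the focal terms are computed element-wise as alpha * (1 - p)^gamma * (-y * log p). The helper name `per_pixel_focal_loss` is hypothetical, and the clipping constant mirrors Keras' default epsilon (1e-7). It uses the image-based example from this thread:

```python
import numpy as np

EPS = 1e-7  # mirrors Keras' default K.epsilon()


def per_pixel_focal_loss(y_true, y_pred, gamma=0.0, alpha=1.0):
    """Hypothetical sketch: focal terms summed over the class axis only,
    leaving one value per pixel (function-style, no reduction)."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    terms = alpha * (1.0 - p) ** gamma * (-y_true * np.log(p))
    return terms.sum(axis=-1)


# Image-based batch: [batch, height, width, classes]
y_true = np.array([[[[1, 0, 0, 0], [0, 1, 0, 0]], [[0, 0, 0, 1], [0, 0, 1, 0]]],
                   [[[0, 1, 0, 0], [0, 1, 0, 0]], [[1, 0, 0, 0], [0, 0, 0, 1]]]], dtype=float)
y_pred = np.array([[[[0.8, 0.0, 0.2, 0.0], [0.0, 0.95, 0.0, 0.05]],
                    [[0.1, 0.2, 0.3, 0.4], [0.5, 0.0, 0.5, 0.0]]],
                   [[[0.0, 0.6, 0.0, 0.4], [0.1, 0.80, 0.1, 0.00]],
                    [[0.7, 0.0, 0.3, 0.0], [0.2, 0.0, 0.3, 0.5]]]])

per_pixel = per_pixel_focal_loss(y_true, y_pred)  # like K.sum(loss, axis=-1)
reduced = per_pixel.mean()                        # like K.mean(K.sum(loss, axis=-1))
print(per_pixel.shape)  # (2, 2, 2)
print(reduced)
```

Summing over axis=-1 alone keeps the [batch, height, width] shape. The extra mean collapses it to a single scalar, which corresponds to what the reducing class returns for an unweighted loss.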
According to https://keras.io/api/losses/, training should be equivalent in both cases, since the page states the following:
"Note that this is an important difference between loss functions like tf.keras.losses.mean_squared_error and default loss class instances like tf.keras.losses.MeanSquaredError: the function version does not perform reduction, but by default the class instance does.
When using fit(), this difference is irrelevant since reduction is handled by the framework."
However, if you have found a problem, perhaps that statement is imprecise and some difference does exist.
Below is the code I used to check categorical_focal_loss against categorical_crossentropy, showing that they behave in the same way:
```python
import numpy as np
import tensorflow as tf
# categorical_focal_loss: the function from this repository (import omitted)

# Pixel-based batch size
y_true = np.array([[0, 1, 0], [0, 0, 1]])
y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
print("Pixel-based labelling")
print("Data dimension as [batch_size (number of pixels), one_hot_encoding of pixel label]")
print(y_true.shape)
print("categorical_cross_entropy")
cce = tf.keras.losses.categorical_crossentropy
print(cce(y_true, y_pred).numpy())
print("focal_loss")
cfl = categorical_focal_loss(gamma=0., alpha=1.)
print(cfl(y_true, y_pred).numpy())

# Image-based batch size
y_true = np.array([[[[1, 0, 0, 0], [0, 1, 0, 0]], [[0, 0, 0, 1], [0, 0, 1, 0]]],
                   [[[0, 1, 0, 0], [0, 1, 0, 0]], [[1, 0, 0, 0], [0, 0, 0, 1]]]])
y_pred = np.array([[[[0.8, 0.0, 0.2, 0.0], [0.0, 0.95, 0.0, 0.05]], [[0.1, 0.2, 0.3, 0.4], [0.5, 0.0, 0.5, 0.0]]],
                   [[[0.0, 0.6, 0.0, 0.4], [0.1, 0.80, 0.1, 0.00]], [[0.7, 0.0, 0.3, 0.0], [0.2, 0.0, 0.3, 0.5]]]])
print("Image-based labelling")
print("Data dimension as [batch_size (number of images), height, width, one_hot_encoding of pixel label]")
print(y_true.shape)
print("categorical_cross_entropy")
print(cce(y_true, y_pred).numpy())
print("focal_loss")
print(cfl(y_true, y_pred).numpy())
```
Hi @DanielPonsa,
I wanted to mimic the behaviour of the class tf.keras.losses.CategoricalCrossentropy, so I changed the return to K.mean(K.sum(loss, axis=-1)), and now they behave in the same way.
Using your code, I got:

```
categorical_cross_entropy
1.176939193690798
focal_loss
1.176939193690798
```

and

```
categorical_cross_entropy
0.4584582572143423
focal_loss
0.4584582572143423
```
I'm going to make this change and add a unit test as well.
Thanks a lot for the contribution!
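Here is a sketch of what such a unit test could check. It is written against a hypothetical NumPy stand-in (`categorical_focal_loss_np`) for the fixed loss rather than the repository's Keras function, so it runs without TensorFlow. The helper and test names are illustrative only:

```python
import numpy as np

EPS = 1e-7  # mirrors Keras' default K.epsilon()


def categorical_focal_loss_np(y_true, y_pred, gamma=2.0, alpha=1.0):
    """Hypothetical NumPy stand-in for the fixed loss: K.mean(K.sum(loss, axis=-1))."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    loss = alpha * (1.0 - p) ** gamma * (-y_true * np.log(p))
    return np.mean(np.sum(loss, axis=-1))


def categorical_crossentropy_np(y_true, y_pred):
    """Mean cross-entropy over samples/pixels, rescaling rows to sum to 1 as Keras does."""
    p = np.clip(y_pred / y_pred.sum(axis=-1, keepdims=True), EPS, 1.0 - EPS)
    return np.mean(-np.sum(y_true * np.log(p), axis=-1))


# Pixel-based example from this thread
PIXEL_TRUE = np.array([[0, 1, 0], [0, 0, 1]], dtype=float)
PIXEL_PRED = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])


def test_reduces_to_crossentropy():
    # gamma = 0 and alpha = 1 should give plain categorical cross-entropy
    fl = categorical_focal_loss_np(PIXEL_TRUE, PIXEL_PRED, gamma=0.0, alpha=1.0)
    ce = categorical_crossentropy_np(PIXEL_TRUE, PIXEL_PRED)
    assert np.isclose(fl, ce)


def test_gamma_down_weights_loss():
    # each term is scaled by (1 - p)^gamma <= 1, so gamma > 0 cannot increase the loss
    fl0 = categorical_focal_loss_np(PIXEL_TRUE, PIXEL_PRED, gamma=0.0)
    fl2 = categorical_focal_loss_np(PIXEL_TRUE, PIXEL_PRED, gamma=2.0)
    assert fl2 < fl0
```

With pytest, the test_* functions are collected automatically. A test of the real Keras function could compare it with tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred) in the same way.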