sicara/tf-explain

Support for Binary Classification Models


Hi, first of all, thank you for tf-explain.

Currently I'm trying to use tf-explain with a model like this one (initial_filters and kernel_size are filled in with example values below):

import tensorflow as tf

initial_filters = 16  # example value, not my exact setting
kernel_size = (3, 3)  # example value, not my exact setting

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(initial_filters, kernel_size, activation='relu', input_shape=(256, 256, 3), padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))  # 128x128
model.add(tf.keras.layers.Conv2D(initial_filters*2, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))  # 64x64
model.add(tf.keras.layers.Conv2D(initial_filters*4, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))  # 32x32
model.add(tf.keras.layers.Conv2D(initial_filters*8, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))  # 16x16
model.add(tf.keras.layers.Conv2D(initial_filters*16, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))  # 8x8
model.add(tf.keras.layers.Conv2D(initial_filters*32, kernel_size, activation='relu', padding="same"))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation="relu"))
model.add(tf.keras.layers.Dense(1))  # single logit for binary classification

This is a model for a binary classification task on the cats vs. dogs dataset.
Using the tf-explain GradCAM callback does not seem to produce correct results.
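
For context, I'm attaching the callback roughly like this (x_val, y_val, x_train, y_train, and output_dir are placeholders for my data and log directory):

from tf_explain.callbacks.grad_cam import GradCAMCallback

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,  # with a single output unit, only index 0 exists
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, epochs=10, callbacks=callbacks)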

I think this is due to the following line in the code:

https://github.com/sicara/tf-explain/blob/master/tf_explain/core/grad_cam.py#L85

where, basically, you take the logit at the index corresponding to the selected class.
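
In other words, the relevant step looks roughly like this (paraphrasing the linked line; grad_model, inputs, and class_index come from the surrounding function):

with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(inputs)
    # the differentiated score is the logit at the requested class index
    loss = predictions[:, class_index]

grads = tape.gradient(loss, conv_outputs)

With a single-unit output, predictions has shape (batch, 1): class_index=0 always selects the same logit whichever class you ask for, and class_index=1 is out of range.
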
A better approach would be to check the shape of the model output and (sketch below):

  • if the requested class is 1, simply take the gradient of the output
  • if the requested class is 0, take the gradient of -output
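
A minimal sketch of what I mean, assuming the single logit scores class 1:

with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(inputs)
    if predictions.shape[-1] == 1:
        # binary model with a single logit: flip the sign to explain class 0
        sign = 1.0 if class_index == 1 else -1.0
        loss = sign * predictions[:, 0]
    else:
        loss = predictions[:, class_index]

grads = tape.gradient(loss, conv_outputs)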

What do you think about this issue and this (possible) fix?

@EmanueleGhelfi Hi Emanuele! I think this would add some complexity to the code for a particular case. Maybe you could switch your final Dense(1) layer to Dense(2, activation='softmax')? Then you would be able to select class_index 0 or 1.
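
Something like this (untested sketch; your loss and labels would need to change accordingly):

model.add(tf.keras.layers.Dense(2, activation='softmax'))  # instead of Dense(1)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # integer labels 0/1
    metrics=['accuracy'],
)

Then class_index=0 and class_index=1 in the GradCAM callback map directly onto your two classes.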

I'm trying to avoid raising exceptions, to prevent training runs from breaking just because of a callback. I might add a warning, though. Thanks!
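
For reference, the kind of warning I have in mind would look something like this (hypothetical, not committed code):

import warnings

if model.output_shape[-1] == 1:
    warnings.warn(
        "Model has a single output unit; GradCAM assumes one unit per class "
        "and class_index selection may give misleading results."
    )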