sicara/tf-explain

Occlusion sensitivity doesn't give output

rao208 opened this issue · 8 comments

Hello folks,

It is not really a bug in tf_explain, but I am reaching out to the tf_explain community here because I don't know who else to ask.

I am working on the Synthetic MNIST dataset with image size (64, 64, 3), downloaded from Kaggle. The train, test, and validation images were brightened and then sharpened before being normalized (i.e. divided by 255).

Original image:

[image: original training samples]

Final output:

[image: brightened and sharpened sample]

Since the pixel values don't follow a Gaussian distribution (I used np.histogram to inspect the distribution of the images), I avoided standardization, i.e. subtracting the mean and dividing by the standard deviation.
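
For reference, the check was roughly along these lines (a minimal sketch; train_x stands for the unnormalized (N, 64, 64, 3) training array):

import numpy as np

# Histogram over raw pixel values to eyeball the distribution
counts, bin_edges = np.histogram(train_x.ravel(), bins=50)
print(counts)
print(bin_edges)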

My CNN looks like this:

[image: sequential model summary]

and the results are:

[image: confusion matrix and test accuracy]

[image: training and validation accuracy/loss curves]

However, when I apply Occlusion Sensitivity to the data with patch size 20, I am not getting the expected output. What I mean is: when I apply occlusion sensitivity to, say, 3 samples from class 0, the heatmaps end up in the same location.

test_data = ([x_test[sampleid]], None)

# Instantiation of the explainer
explainer = OcclusionSensitivity()

# Call to the explain() method
output = explainer.explain(test_data, model, class_index=classidx,
                           patch_size=20, colormap=cv2.COLORMAP_JET)

[images: occlusion heatmaps for three class-0 samples, patch size 20]

This is true for all classes. Does it mean my CNN is not learning anything? That can't be the case, because when I apply Grad-CAM I do get an output, i.e. different heatmap locations for different images of the same class. Or does it mean this is the correct output? If so, does it make sense to get heatmaps at the same location for different samples of the same class?
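
For reference, the Grad-CAM call was simply along these lines (a sketch; test_data, model, and classidx are as in the occlusion snippet above, and recent tf-explain versions infer the target convolutional layer when layer_name is omitted):

from tf_explain.core.grad_cam import GradCAM

explainer = GradCAM()
# class_index selects the class whose activations are visualized
output = explainer.explain(test_data, model, class_index=classidx)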

Any help would be appreciated. Please help me, because this is very important for my thesis. I have spent almost a month trying to figure this out.

If you need any further information, let me know.

Best regards.

@rao208 You might want to reduce the patch size: a patch size of 20 means you only apply 3 patches along the x axis. Among the 9 patches, the bottom-right one might be giving less information, hence the globally red colormap. You might want to use a patch size of 5, or several patch size values (e.g. [2, 5, 10]), to be able to compare multiple attribution maps.
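
For example, a minimal sketch of such a comparison, reusing the test_data, model, and classidx from your snippet (the output filenames are illustrative):

explainer = OcclusionSensitivity()

for patch_size in [2, 5, 10]:
    output = explainer.explain(test_data, model, class_index=classidx,
                               patch_size=patch_size, colormap=cv2.COLORMAP_JET)
    # save() writes each attribution map to disk for side-by-side comparison
    explainer.save(output, ".", "occlusion_patch_{}.png".format(patch_size))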

Thank you for the quick response @RaphaelMeudec. Even with the different patch sizes, the heatmap location is the same. Here are the results:

Patch size 5

[images: occlusion heatmaps for three class-0 samples, patch size 5]

Patch size 10

[images: occlusion heatmaps for three class-0 samples, patch size 10]

Is there any problem with how test_data is passed to the occlusion sensitivity explainer? (The code is attached in the question above.)

I just can't figure out what the cause could be. I worked with CIFAR-10 as well and I see a similar pattern there too, i.e. different heatmap locations for different images from the same class.

Could you provide the link to the dataset and the training script? (in particular the preprocessing you apply to the images)

@RaphaelMeudec
The link to the dataset on Kaggle is https://www.kaggle.com/prasunroy/synthetic-digits?

There are two folders, imgs_train and imgs_valid, each containing 10 subfolders. I first put all the training digits in a 'train' folder and the testing digits in a 'test' folder, then converted them into .npy files. I tried to compress the folders into a .zip file, but it was too big to upload here. Nevertheless, you can access the .npy files from my drive (https://drive.google.com/drive/u/2/folders/1rjQ0CjaiiNcuHhXJptlvs9u-Fog6LAoW).
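
The conversion was roughly along these lines (a minimal sketch; the paths, file extension, and array names are illustrative, not the exact script used):

import numpy as np
from PIL import Image
from pathlib import Path

images, labels = [], []

# Each digit lives in its own subfolder named '0' through '9'
for digit_dir in sorted(Path('imgs_train').iterdir()):
    for img_path in digit_dir.glob('*.png'):
        images.append(np.asarray(Image.open(img_path).convert('RGB')))
        labels.append(int(digit_dir.name))

np.save('train_x_64.npy', np.stack(images))
np.save('train_y_64.npy', np.array(labels))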

Please let me know if you are unable to open the link or download the files.

The code is:

# -*- coding: utf-8 -*-
"""
Created on Mon May 11 16:31:50 2020

@author: Vanditha Rao
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, BatchNormalization
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import optimizers
from sklearn.model_selection import train_test_split
from skimage import exposure
from PIL import Image
from PIL import ImageEnhance
from tensorflow.keras import regularizers
tf.keras.backend.clear_session()

with tf.device('/device:GPU:0'):
    
    class synthetic_mnist_model:
        
        def __init__(self):
            
            self.x_shape = [64,64,3]
            self.batch_size = 64
            self.maxepoches = 50
            self.num_classes = 10
            self.weight_decay = 0.0005
            self.model = self.build_model()

        def plot_confusion_matrix(self, y_true, y_pred, classes, title = None, cmap=plt.cm.Blues):
                
    
            """
            This function prints and plots the confusion matrix.
            Normalization can be applied by setting `normalize=True`.
            """
    
            # Compute confusion matrix
    
            cm = confusion_matrix(y_true, y_pred)
    
            fig, ax = plt.subplots(figsize=(10,10))
            
            im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
            ax.figure.colorbar(im, ax=ax)
            
            # We want to show all ticks...
            
            ax.set(xticks=np.arange(cm.shape[1]),
                    yticks=np.arange(cm.shape[0]),
                    # ... and label them with the respective list entries
                    xticklabels=classes, yticklabels=classes,
                    title=title,
                    ylabel='True label',
                    xlabel='Predicted label')
    
            # Rotate the tick labels and set their alignment.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                      rotation_mode="anchor")
    
            # Loop over data dimensions and create text annotations.
            fmt = 'd'     
            thresh = cm.max() / 2.0
            
            for i in range(cm.shape[0]):
                for j in range(cm.shape[1]):
                    ax.text(j, i, format(cm[i, j], fmt),
                            ha="center", va="center",
                            color="white" if cm[i, j] > thresh else "black")
                 
            ax.set_ylim(len(cm)-0.5, -0.5)
    
            fig.tight_layout()
            return ax
        
        def load_data(self):
        
            test_x = np.load('./data/test_x_64.npy')
            test_y = np.load('./data/test_y_64.npy')
            
            train_x = np.load('./data/train_x_64.npy')
            train_y = np.load('./data/train_y_64.npy')
            
            train_x = train_x.astype('float32')
            test_x = test_x.astype('float32')
        
            return test_x, test_y, train_x, train_y
        
        
        
        def change_brightness(self,img, brightness = 1.2):
    
            enh_bri = ImageEnhance.Brightness(img)
            image_brightened = enh_bri.enhance(brightness)
            
            return image_brightened
        
        def brightness(self, X_train, X_test, X_val):
            
            bright_train = np.zeros(X_train.shape)
            bright_test = np.zeros(X_test.shape)
            bright_val = np.zeros(X_val.shape)

            for i in range(X_train.shape[0]):
                image = Image.fromarray(X_train[i, :, :, :].astype(np.uint8))
                bright_train[i, :, :, :] = self.change_brightness(image)
            
            bright_train = bright_train.astype('float32')
            
            for i in range(X_test.shape[0]):
                image = Image.fromarray(X_test[i, :, :, :].astype(np.uint8))
                bright_test[i, :, :, :] = self.change_brightness(image)
            
            bright_test = bright_test.astype('float32')
            
            for i in range(X_val.shape[0]):
                image = Image.fromarray(X_val[i, :, :, :].astype(np.uint8))
                bright_val[i, :, :, :] = self.change_brightness(image)
            
            bright_val = bright_val.astype('float32')
            
            return bright_train, bright_test, bright_val

        
        def change_sharpness(self,img, sharpness = 2.0):
            enh_sha = ImageEnhance.Sharpness(img)
            image_sharped = enh_sha.enhance(sharpness)
    
            return image_sharped
        
        def sharpness(self, X_train, X_test, X_val):
            
            sharp_train = np.zeros(X_train.shape)
            sharp_test = np.zeros(X_test.shape)
            sharp_val = np.zeros(X_val.shape)

            for i in range(X_train.shape[0]):
                image = Image.fromarray(X_train[i, :, :, :].astype(np.uint8))
                sharp_train[i, :, :, :] = self.change_sharpness(image)
            
            sharp_train = sharp_train.astype('float32')
            
            for i in range(X_test.shape[0]):
                image = Image.fromarray(X_test[i, :, :, :].astype(np.uint8))
                sharp_test[i, :, :, :] = self.change_sharpness(image)
            
            sharp_test = sharp_test.astype('float32')
            
            for i in range(X_val.shape[0]):
                image = Image.fromarray(X_val[i, :, :, :].astype(np.uint8))
                sharp_val[i, :, :, :] = self.change_sharpness(image)
            
            sharp_val = sharp_val.astype('float32')
            
            
            return sharp_train, sharp_test, sharp_val
               
        
        def one_hot_encode(self, Y_train, Y_test, Y_val):
            
            Y_train = tf.keras.utils.to_categorical(Y_train, self.num_classes)
            Y_test = tf.keras.utils.to_categorical(Y_test, self.num_classes)
            Y_val = tf.keras.utils.to_categorical(Y_val, self.num_classes)
            
            return Y_train, Y_test, Y_val
        
        def predict(self, x):
            return self.model.predict(x, self.batch_size)
        
        def evaluate(self,x,y):
            return self.model.evaluate(x,y, self.batch_size, verbose=2)
        
        def build_model(self):
            
            model = Sequential()
            
            weight_decay = 0.001

            model.add(Conv2D(16, kernel_size= (4,4),
                             input_shape = self.x_shape,
                             padding = 'same',
                             activation = "relu",
                             # kernel_initializer='he_normal'
                             kernel_regularizer=regularizers.l2(weight_decay)
                             ))
            
            model.add(BatchNormalization())
            
            model.add(Conv2D(16, kernel_size= (4,4),
                             padding = 'same',
                             # activation= tf.nn.leaky_relu,))
                             activation = "relu",
                             # kernel_initializer='he_normal'
                             kernel_regularizer=regularizers.l2(weight_decay)
                             ))
            
            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))
            # model.add(Dropout(0.2))
            
            model.add(Conv2D(32, kernel_size=(3,3),
                             padding = 'same',
                             activation = "relu",
                             # kernel_initializer='he_normal'
                             kernel_regularizer=regularizers.l2(weight_decay)
                             ))
            model.add(BatchNormalization())
            
            model.add(Conv2D(32, kernel_size=(3,3),
                             padding = 'same',
                             activation = "relu",
                             # kernel_initializer='he_normal'))
                             kernel_regularizer=regularizers.l2(weight_decay)))
            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))
            # model.add(Dropout(0.3))
            
            model.add(Conv2D(64, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)
                              ))
            model.add(BatchNormalization())
            
            model.add(Conv2D(64, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)))
            
            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))
            
            model.add(Conv2D(128, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)
                              ))
            model.add(BatchNormalization())
            
            model.add(Conv2D(128, kernel_size=(3,3),
                              padding = 'same',
                              activation = "relu",
                              kernel_regularizer=regularizers.l2(weight_decay)))
            
            model.add(BatchNormalization())
            model.add(MaxPooling2D(pool_size=(2,2)))
            
            
            model.add(Flatten())
                    
            model.add(Dense(1152, activation = "relu",
                            # kernel_regularizer=regularizers.l2(weight_decay)
                            ))
            
            model.add(Dropout(0.5))
            model.add(Dense(1152, activation = "relu",
                            # kernel_regularizer=regularizers.l2(weight_decay)
                            ))
            
            model.add(Dropout(0.5))
            
            model.add(Dense(10, activation= "softmax"))
            model.summary()
            
            return model
        
        def train(self, train_x, train_y, val_x, val_y):
            
            batch_size = 64

            datagen = ImageDataGenerator(rotation_range=15,
                                         width_shift_range=0.1,
                                         height_shift_range=0.1,
                                         horizontal_flip=True,
                                         )
            
            datagen.fit(train_x)
   
            self.model.compile(loss='categorical_crossentropy',
                               optimizer=optimizers.Adadelta(learning_rate=1.0),
                               metrics=['accuracy'])
            
            history = self.model.fit(datagen.flow(train_x, train_y, batch_size=batch_size),
                                steps_per_epoch=train_x.shape[0] // batch_size,
                                epochs=self.maxepoches,
                                validation_data=(val_x, val_y),
                                verbose=2)

            
            return history    
        
        def save(self):
            
            model_json = self.model.to_json()
        
            self.model.save_weights('./model/model_weights_synthetic_mnist_os_bright_sharp_gap.h5')
            self.model.save('./model/model_synthetic_mnist_os_bright_sharp_gap.h5')
            
            with open('./model/model_synthetic_mnist_os_bright_sharp_gap.json', 'w') as json_file:
                json_file.write(model_json)

                
        def plot(self, history):
            
            train_acc = history.history['accuracy']
            validation_acc = history.history['val_accuracy']
        
            train_loss = history.history['loss']
            validation_loss = history.history['val_loss']
        
            plt.figure(figsize=(10,10))
        
            plt.subplot(1,2,1)          
            plt.plot(train_acc,'r',label='Training Accuracy')
            plt.plot(validation_acc,'b',label='Validation Accuracy')
        
            plt.title('Training and Validation Accuracy')
            plt.legend()
        
            plt.subplot(1,2,2)
            
            plt.plot(train_loss,'r',label='Training loss')
            plt.plot(validation_loss,'b',label='Validation loss')
        
            plt.title('Training and Validation loss')
            plt.legend()
            plt.show()
            

    if __name__ == '__main__':
           
        
        sm = synthetic_mnist_model()
        
        # load the dataset
        
        test_x, test_y, train_x, train_y = sm.load_data()
        
        # split the training set into training and validation
        
        train_x, val_x, train_y, val_y = train_test_split(train_x, train_y,
                                                          test_size=0.2,
                                                          random_state=1234,
                                                          shuffle = True,
                                                          stratify=train_y
                                                          )
        print(train_x.shape)
        print(test_x.shape)
        print(val_x.shape)
        
        # Plot the training images
    
        fig = plt.figure(figsize=(5,4))

        
        for i in range(3):
            for j in range(3):
                ax = fig.add_subplot(3, 3, i * 3 + j + 1)
                ax.imshow(train_x[i * 3 + j]/255)
                
        plt.show()
        
        # brighten the images
        
        train_x, test_x, val_x = sm.brightness(train_x, test_x, val_x)
        
        # plot brighten images
        
        fig1 = plt.figure(figsize=(5,4))
        for i in range(3):
            for j in range(3):
                ax1 = fig1.add_subplot(3, 3, i * 3 + j + 1)
                ax1.imshow(train_x[i * 3 + j]/255)
                
        plt.show()    
        
        # sharpen the images 
        
        train_x, test_x, val_x = sm.sharpness(train_x, test_x, val_x)
        
        # normalize the image
        
        train_x /= 255
        test_x /= 255
        val_x /= 255
        
        # view sharp images 
        fig2 = plt.figure(figsize=(5,4))
        for i in range(3):
            for j in range(3):
            
                ax2 = fig2.add_subplot(3, 3, i * 3 + j + 1)
                ax2.imshow(train_x[i * 3 + j])
                
        plt.show()
        
        # one hot encode
            
        train_y, test_y, val_y = sm.one_hot_encode(train_y, test_y, val_y)
                
        # train the model
        history = sm.train(train_x, train_y, val_x, val_y)
                
        # save the model
        
        sm.save()
        
        # Plot accuracy and loss
        
        sm.plot(history)
        
        
        # predict
        
        predicted_x = sm.predict(test_x)
        residuals = np.argmax(predicted_x,1)!=np.argmax(test_y,1)
    
        loss = sum(residuals) / len(residuals)
        print("the test 0/1 loss is: ", loss)
        
        # evaluate on test dataset
        
        loss, acc = sm.evaluate(test_x, test_y)
        print('Test Accuracy: %.3f' % (acc * 100))
        
        
        # plot confusion matrix and print classification report
        classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        labels = [0,1,2,3,4,5,6,7,8,9]
        
        test_y =  test_y.argmax(axis=1)
        y_pred = predicted_x.argmax(axis=1)
        
        print(classification_report(test_y, y_pred, target_names= classes))
        
        np.set_printoptions(precision=2)

        ## Plot non-normalized confusion matrix
        
        sm.plot_confusion_matrix(test_y, y_pred, classes=classes, title='Confusion matrix', cmap=plt.cm.Blues)
        plt.show()

@RaphaelMeudec

Update:

I have observed a similar pattern when I use the Albumentations image augmentation library.

@RaphaelMeudec Please help me. Do you think there is a bug in my code? I have tried different preprocessing techniques (Albumentations data augmentation, standardization, normalization, brightening and sharpening the images).

@RaphaelMeudec What is the purpose of the grid_display function? I went through your occlusion sensitivity code and I understand everything except grid_display. What is the significance of that function, and what happens if we do not use it?

@rao208 I explain a little bit about grid_display here, and how you can avoid using it. Hope it helps.
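
For reference, a minimal sketch of bypassing grid_display and inspecting the raw sensitivity map directly (this assumes the get_sensitivity_map helper that explain() calls internally; check that your tf-explain version exposes it):

import matplotlib.pyplot as plt
from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

explainer = OcclusionSensitivity()

# Raw per-patch confidence map, before the colormap overlay and grid tiling
sensitivity_map = explainer.get_sensitivity_map(model, x_test[sampleid], classidx, patch_size=10)

plt.imshow(sensitivity_map)
plt.colorbar()
plt.show()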