How to apply this to any classifier?
adebayoj opened this issue · 4 comments
Hi, thank you for this contribution.
I am wondering if you have example code showing how to apply this to classifiers other than ResNet-50, Inception-v3, or VGG16. I am hoping to get excitation backprop figures for a 3-layer CNN and a 3-layer fully connected network.
Thanks for the great work!
Gradient-based methods (including excitation backprop) work by replacing every layer with a customized layer, so in general they can be applied to any CNN model, like this:
model = YourModel()  # placeholder for your own network, e.g. a 3-layer CNN
explainer = excitation_backprop(
    model,
    output_layer_keys=['xxx'],  # key of the layer in your model where backprop ends
)
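To illustrate what the customized layers compute, here is a minimal, framework-free sketch of the excitation backprop rule (marginal winning probabilities) for a small fully connected ReLU network like the 3-layer one asked about. It is not this repository's implementation: the helper names `excitation_step` and `excitation_map` are hypothetical, biases are omitted, and the network input is assumed to be non-negative (e.g. pixel intensities in [0, 1]).

```python
from typing import List, Tuple

Matrix = List[List[float]]  # w[i][j]: weight from input unit j to output unit i
Vector = List[float]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def linear(w: Matrix, x: Vector) -> Vector:
    # y_i = sum_j w[i][j] * x[j]  (bias omitted in this sketch)
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def excitation_step(w: Matrix, a: Vector, p_out: Vector, eps: float = 1e-12) -> Vector:
    """Pass winning probabilities from a layer's outputs down to its inputs.

    P(a_j) = sum_i [a_j * w_ij^+ / sum_k (a_k * w_ik^+)] * P(a_i),
    where a holds the (non-negative) input activations of the layer.
    """
    p_in = [0.0] * len(a)
    for i, row in enumerate(w):
        pos = [max(0.0, wij) for wij in row]        # keep only excitatory weights
        z = sum(pj * aj for pj, aj in zip(pos, a))  # normalizer for output unit i
        if z <= eps:
            continue  # no excitatory input reaches unit i; its mass is dropped
        for j, aj in enumerate(a):
            p_in[j] += aj * pos[j] / z * p_out[i]
    return p_in


def forward(weights: List[Matrix], x: Vector) -> Tuple[List[Vector], Vector]:
    """Return the input activations of every linear layer and the final scores."""
    acts, h = [x], x
    for k, w in enumerate(weights):
        h = linear(w, h)
        if k < len(weights) - 1:  # ReLU on hidden layers only
            h = relu(h)
            acts.append(h)
    return acts, h


def excitation_map(weights: List[Matrix], x: Vector, target: int) -> Vector:
    """Hypothetical helper: winning probability of each input unit for `target`."""
    acts, scores = forward(weights, x)
    p = [0.0] * len(scores)
    p[target] = 1.0  # all probability mass starts at the target class
    for w, a in zip(reversed(weights), reversed(acts)):
        p = excitation_step(w, a, p)
    return p


# Toy 3-layer fully connected network with 2 inputs, 2 hidden units per layer, 2 classes.
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[1.0, 1.0], [1.0, -1.0]]
W3 = [[1.0, 0.0], [0.0, 1.0]]
print(excitation_map([W1, W2, W3], x=[1.0, 2.0], target=0))  # roughly [1/3, 2/3]
```

Convolutions are also linear maps, so the same per-layer rule carries over to a 3-layer CNN by applying it to each kernel's positive weights; stopping the loop early corresponds to choosing a different end layer.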
Thanks for the response. Will this also work for the realtime saliency method?
No. Real Time Saliency is not a gradient-based method. It trains an auxiliary network based on ResNet-50. However, you can follow their objective and learning algorithm to train your own explainer network.
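As a rough sketch of the kind of objective such an explainer network is trained on, the function below evaluates a masking loss assumed to have the form λ1·TV(m) + λ2·AV(m) − log f_c(x with only m kept) + λ3·f_c(x with m removed)^λ4, where m is the mask, TV is total variation, AV is the average mask area and f_c is the classifier's probability for class c. The exact terms, the squared-difference TV and all λ values are assumptions to check against the Real Time Saliency paper; the function names are hypothetical, and the two classifier probabilities are passed in rather than computed.

```python
import math
from typing import List

Mask = List[List[float]]  # values in [0, 1]; 1 means "keep this pixel"


def total_variation(m: Mask) -> float:
    """Sum of squared differences between horizontally and vertically adjacent cells."""
    rows, cols = len(m), len(m[0])
    tv = 0.0
    for r in range(rows):
        for c in range(cols):
            if c + 1 < cols:
                tv += (m[r][c] - m[r][c + 1]) ** 2
            if r + 1 < rows:
                tv += (m[r][c] - m[r + 1][c]) ** 2
    return tv


def average_area(m: Mask) -> float:
    """Mean mask value, i.e. the fraction of the image the mask keeps."""
    return sum(sum(row) for row in m) / (len(m) * len(m[0]))


def saliency_objective(m: Mask, p_keep: float, p_drop: float,
                       lam1: float, lam2: float, lam3: float, lam4: float,
                       eps: float = 1e-12) -> float:
    """Hypothetical masking loss.

    p_keep: classifier probability of class c when only the masked region is kept.
    p_drop: classifier probability of class c when the masked region is removed.
    """
    return (lam1 * total_variation(m)          # favour smooth masks
            + lam2 * average_area(m)           # favour small masks
            - math.log(max(p_keep, eps))       # kept region alone should be recognised
            + lam3 * p_drop ** lam4)           # removing the region should destroy the class


print(saliency_objective([[0.5, 0.5], [0.5, 0.5]], p_keep=1.0, p_drop=0.0,
                         lam1=1.0, lam2=3.0, lam3=1.0, lam4=0.5))
```

In actual training the mask comes from the explainer network and this loss is minimized with gradient descent through an autodiff framework; the plain-Python version only makes each term explicit.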
Great, thanks.