histocartography/patho-quant-explainer

Ability to extend this to node classification?

Closed this issue · 1 comment

Hello!

Thanks a lot for making the histocartography package and open-sourcing this code. Do you think it would be possible to generalize this pipeline to node classification? I am currently working on the interpretability of my model, and the generated histograms would be incredibly valuable, but the pipeline seems to work only for graph classification. Any help on this? I really appreciate it!

Hi,

As you correctly pointed out, the framework is currently only able to explain graph classification tasks. In order to extend it to node classification, I would suggest the following steps:

  • Work with GraphGradCAM; the other explainers will require more time and effort to adapt
  • Start from this implementation: https://github.com/histocartography/histocartography/blob/main/histocartography/interpretability/grad_cam.py
  • Assuming you want to explain one node at a time:
    • declare your GraphGradCAM explainer the same way as for graph classification tasks, i.e., by providing the model and the names of your GNN layers
    • the call function will still take a class index (probably the predicted class) and the predicted logits, now those of the query node
    • _get_weights should stay the same (line 110), i.e., it computes an importance score for each channel at each GNN layer
    • lines 114 to 118 need to be adapted
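The steps above can be sketched in a framework-agnostic way. This is a minimal NumPy sketch, not the histocartography code: it assumes you have already captured each GNN layer's node activations (e.g. with forward hooks) and the gradients of the query node's logit for the chosen class with respect to those activations. `node_grad_cam` is a hypothetical helper, and averaging the per-layer maps is an assumption, not necessarily what the library does.

```python
import numpy as np


def node_grad_cam(activations, gradients):
    """Node-level Grad-CAM sketch (hypothetical helper, not the histocartography API).

    activations: list of (num_nodes, num_channels) arrays, one per GNN layer.
    gradients: list of arrays with the same shapes, holding d logits[node_idx, class_idx]
        / d activations for each layer. In PyTorch these could come from
        torch.autograd.grad(logits[node_idx, class_idx], hooked_activations) (assumption).
    Returns a (num_nodes,) array of node importances scaled to [0, 1].
    """
    layer_cams = []
    for acts, grads in zip(activations, gradients):
        # Channel weights: average gradient over all nodes (the _get_weights idea).
        weights = grads.mean(axis=0)              # (num_channels,)
        # Weighted sum of channels per node, keeping only positive evidence.
        cam = np.maximum(acts @ weights, 0.0)     # (num_nodes,)
        layer_cams.append(cam)
    # Aggregate over layers by averaging (an assumption; other choices exist).
    cam = np.mean(layer_cams, axis=0)
    # Min-max normalize; return zeros if every node got the same score.
    span = cam.max() - cam.min()
    if span == 0:
        return np.zeros_like(cam)
    return (cam - cam.min()) / span


# Toy example: 3 nodes, 2 channels, one layer; the gradient values are made up.
acts = [np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])]
grads = [np.array([[0.3, -0.3], [0.0, 0.0], [0.3, 0.0]])]
print(node_grad_cam(acts, grads))
```

The only change relative to graph classification is which scalar you differentiate: the logit of the query node for the chosen class instead of the graph-level logit. The channel weighting and the per-node weighted sum keep the same form.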

Hope this helps