m-pax_lib




Improving Explainability of Disentangled Representations using Multipath-Attribution Mappings

Read the paper »

📌  Introduction

Explainable AI aims to render model behavior understandable by humans, which can be seen as an intermediate step in extracting causal relations from correlative patterns. Due to the high risk of possible fatal decisions in image-based clinical diagnostics, it is necessary to integrate explainable AI into these safety-critical systems. Current explanatory methods typically assign attribution scores to pixel regions in the input image, indicating their importance for a model's decision. However, they fall short when explaining why a visual feature is used. We propose a framework that utilizes interpretable disentangled representations for downstream-task prediction. Through visualizing the disentangled representations, we enable experts to investigate possible causation effects by leveraging their domain knowledge. Additionally, we deploy a multi-path attribution mapping for enriching and validating explanations. We demonstrate the effectiveness of our approach on a synthetic benchmark suite and two medical datasets. We show that the framework not only acts as a catalyst for causal relation extraction but also enhances model robustness by enabling shortcut detection without the need for testing under distribution shifts.

🗂  Project Structure

├── README.md
├── LICENSE
├── requirements.txt            - txt file with the environment
├── run_eval.py                 - Main script to execute for evaluation
├── run_head.py                 - Main script to execute for supervised training
├── run_tcvae.py                - Main script to execute for unsupervised pre-training
├── configs                     - Hydra configs
│   ├── config_eval.yaml
│   ├── config_head.yaml
│   ├── config_tcvae.yaml
│   ├── callbacks
│   ├── datamodule
│   ├── evaluation
│   ├── experiment
│   ├── hydra
│   ├── logger
│   ├── model
│   └── trainer
├── data                        - Data storage folders (each filled after first run)
│   ├── DiagVibSix
│   ├── ISIC
│   ├── MNIST
│   ├── models                  - Trained and saved models
│   │   └── dataset_beta        - Copied checkpoints per dataset and beta value
│   │       └── images          - Image export folder
│   └── OCT
├── logs                        - Logs and checkpoints saved per run and date
│   └── runs
│       └── date
│           └── timestamp
│               ├── checkpoints
│               ├── .hydra
│               └── tensorboard
└── src
    ├── evaluate.py             - Evaluation pipeline
    ├── train.py                - Training pipeline
    ├── datamodules             - Datamodule scripts
    ├── evaluation              - Evaluation scripts
    ├── models                  - Lightning modules
    └── utils                   - Various utility scripts (beta-TCVAE loss etc.)
                         

🚀  Usage

All essential libraries for executing the code are provided in the requirements.txt file, from which a new environment can be created (Linux only). For the R script, please install the corresponding libraries beforehand. Set up the package in a conda environment:

git clone https://github.com/IML-DKFZ/m-pax_lib
cd m-pax_lib
conda create -n m-pax_lib python=3.7
source activate m-pax_lib
pip install -r requirements.txt

Depending on your GPU, change the torch and torchvision versions in the requirements.txt file to the respective CUDA-supporting versions. For CPU-only support, append trainer.gpus=0 to every command.
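
For example, running the unsupervised pre-training on the CPU only then looks as follows:

python run_tcvae.py trainer.gpus=0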

Run the code

Once the virtual environment is activated, the code can be run as follows:

Running the scripts without any experiment files will start training and evaluation on MNIST. All parameters are defined in the Hydra config files and are not overwritten by any experiment files. The following commands will first train the β-TCVAE-loss-based model with β = 4, then train the downstream classification head, and finally evaluate the model. The run_tcvae.py script also automatically downloads and extracts the dataset to ./data/MNIST.

python run_tcvae.py
python run_head.py
python run_eval.py

Before training the head, place one of the encoder checkpoints (best or last epoch) from ./logs/runs/date/timestamp/checkpoints at ./models/mnist_beta=4 and rename it to encoder.ckpt. The folder can be renamed, but the new name then also has to be set in the config/model/head_model.yaml and config/evaluation/default.yaml files. Place the head checkpoint in the same folder and rename it to head.ckpt. The evaluation script will automatically create an images folder inside and export all graphics to this location.
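
A minimal sketch of these copy steps (the date, timestamp, and checkpoint file names are placeholders and depend on your runs):

mkdir -p ./models/mnist_beta=4
cp ./logs/runs/<date>/<timestamp>/checkpoints/<encoder_checkpoint>.ckpt ./models/mnist_beta=4/encoder.ckpt
cp ./logs/runs/<date>/<timestamp>/checkpoints/<head_checkpoint>.ckpt ./models/mnist_beta=4/head.ckpt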

Reproduce the results

For all other experiments in the paper, respective experiment files that overwrite the default parameters are provided. The following configurations reproduce the results from the paper for each dataset. You can also add your own experiment yaml files or change the existing ones. For more information see here.

The ISIC and OCT evaluations need a rather large amount of RAM (~80 GB). Reduce the batch size in the isic_eval.yaml or oct_eval.yaml file to get less accurate but more RAM-sparing results.
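
Alternatively, the batch size can likely be overridden directly on the command line via Hydra instead of editing the file; the exact parameter path (datamodule.batch_size is assumed here) depends on the datamodule config:

python run_eval.py +experiment=oct_eval.yaml datamodule.batch_size=32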

DiagViB-6

python run_tcvae.py +experiment=diagvibsix_tcvae.yaml
python run_head.py +experiment=diagvibsix_head.yaml
python run_eval.py +experiment=diagvibsix_eval.yaml seed=43

These commands run the experiment for the ZGO study. For the other two studies, change ZGO to FGO_05 or FGO_20 in the three experiment files.
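
To find the study setting, a quick search over the experiment configs should suffice (assuming the DiagViB-6 experiment files live under configs/experiment, as the project structure above suggests):

grep -rn "ZGO" configs/experiment/diagvibsix_*.yaml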

UCSD OCT Retina Scans

python run_tcvae.py +experiment=oct_tcvae.yaml
python run_head.py +experiment=oct_head.yaml
python run_eval.py +experiment=oct_eval.yaml seed=48

ISIC Skin Lesion Images

python run_tcvae.py +experiment=isic_tcvae.yaml
python run_head.py +experiment=isic_head.yaml
python run_eval.py +experiment=isic_eval.yaml seed=47

GIFs traversing the ten latent space features for five observations of each of the three datasets:

   

📝  How to cite this code

Please cite the original publication:

@inproceedings{
  klein2022improving,
  title={Improving Explainability of Disentangled Representations using Multipath-Attribution Mappings},
  author={Lukas Klein and Jo{\~a}o B. S. Carvalho and Mennatallah El-Assady and Paolo Penna and Joachim M. Buhmann and Paul F Jaeger},
  booktitle={Medical Imaging with Deep Learning},
  year={2022},
  url={https://openreview.net/forum?id=3uQ2Z0MhnoE}
}

Acknowledgements

The code was developed by the authors of the paper. However, it also contains pieces of code from the following packages:

The m-pax_lib is developed and maintained by the Interactive Machine Learning Group of Helmholtz Imaging and the DKFZ, as well as the Information Science and Engineering Group at ETH Zürich.