
TensorFlow implementation of "Born Identity Network: Multi-way Counterfactual Map Generation to Explain a Classifier's Decision"


Born-Identity-Network


Overall framework

  • The goal of Born Identity Network (BIN) is to induce counterfactual reasoning, dependent on a target condition, from a pre-trained model.
  • BIN has two major components: the Counterfactual Map Generator (CMG) and the Target Attribution Network (TAN).
  • The CMG synthesizes a counterfactual map conditioned on an arbitrary target label, while the TAN works towards enforcing the target label's attributes on the synthesized map.
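
To make the idea of a label-conditioned counterfactual map concrete, here is a toy sketch. It assumes the map is added to the input to obtain the counterfactual image, which this README does not state explicitly. `toy_map_generator` is a hypothetical stand-in and is not the CMG from this repository.

```python
import numpy as np


def one_hot(label, num_classes):
    """Encode an integer target label as a one-hot vector."""
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[label] = 1.0
    return vec


def toy_map_generator(x, target):
    """Hypothetical stand-in for the CMG.

    Returns a map with the same shape as x that moves every pixel
    toward an intensity determined by the target label.
    """
    target_intensity = float(np.argmax(target)) / (len(target) - 1)
    return target_intensity - x


def apply_counterfactual_map(x, cf_map):
    """Assumption: the counterfactual image is input + map, clipped to [0, 1]."""
    return np.clip(x + cf_map, 0.0, 1.0)


x = np.full((4, 4), 0.25, dtype=np.float32)   # toy "image"
target = one_hot(2, num_classes=3)            # arbitrary target label
cf_map = toy_map_generator(x, target)
x_cf = apply_counterfactual_map(x, cf_map)
```

In the real model a trained network would produce the map, and the TAN would judge whether the resulting image carries the target label's attributes.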

[Figure: overall framework of BIN]

Results

Counterfactual visual explanations

[Figures: counterfactual visual explanations]

Extra interpolation using 3D Shapes

[Figure: interpolation results on 3D Shapes]

Requirements

tensorflow (2.2.0)
tensorboard (2.2.2)
tensorflow-addons (0.11.0)
tqdm (4.48.0)
matplotlib (3.3.0)
numpy (1.19.0)
scikit-learn (0.23.2)
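
A small script along the following lines could check whether the current environment matches the versions listed above. It is a sketch that treats the listed versions as exact pins, which is an assumption. The helper names `installed_version` and `find_mismatches` are illustrative and are not part of this repository.

```python
from importlib import metadata

# Versions as listed in the Requirements section.
PINNED = {
    "tensorflow": "2.2.0",
    "tensorboard": "2.2.2",
    "tensorflow-addons": "0.11.0",
    "tqdm": "4.48.0",
    "matplotlib": "3.3.0",
    "numpy": "1.19.0",
    "scikit-learn": "0.23.2",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs to (wanted, found)."""
    mismatches = {}
    for package, wanted in pinned.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in find_mismatches(PINNED).items():
        print(f"{package}: expected {wanted}, found {found or 'not installed'}")
```

The `lookup` parameter exists only so the comparison logic can be exercised without installing the packages.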

Datasets

Place each dataset in the directory set as "data_path" in the corresponding Config.py.

  1. Handwritten digits data (MNIST)
  2. 3D Geometric shape data
  3. Alzheimer’s Disease Neuroimaging Initiative (ADNI)

How to run

Mode:
#0 Pre-training a classifier
#1 Training the counterfactual map generator

  1. Pre-training a classifier
  • python training.py --mode=0
  2. Training the counterfactual map generator
  • Set the pre-trained classifier and encoder weights (cls_weight_path and enc_weight_path in Config.py); these weights stay frozen during training
  • Change the mode from 0 to 1 in Config.py
  • python training.py --mode=1
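
For readers unfamiliar with this kind of flag, the sketch below shows one way a `--mode` option with the two values above could be parsed with `argparse`. It is illustrative only: the actual argument handling in training.py may differ, and `build_parser` and `MODE_DESCRIPTIONS` are hypothetical names.

```python
import argparse

# Descriptions taken from the mode list in this README.
MODE_DESCRIPTIONS = {
    0: "Pre-training a classifier",
    1: "Training the counterfactual map generator",
}


def build_parser():
    """Build a parser that accepts --mode=0 or --mode=1 (hypothetical layout)."""
    parser = argparse.ArgumentParser(description="BIN training entry point (sketch)")
    parser.add_argument(
        "--mode",
        type=int,
        choices=sorted(MODE_DESCRIPTIONS),
        default=0,
        help="0: pre-train the classifier, 1: train the counterfactual map generator",
    )
    return parser


# Parse an explicit argument list instead of sys.argv so the example is self-contained.
args = build_parser().parse_args(["--mode=1"])
selected = MODE_DESCRIPTIONS[args.mode]
```

Restricting `choices` to the known modes makes an invalid value fail at startup rather than partway through training.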

Config.py of each dataset

data_path = Path to the raw dataset
save_path = Directory where results such as TensorBoard event files and model weights are saved
cls_weight_path = Path to the pre-trained classifier weights obtained in mode 0
enc_weight_path = Path to the pre-trained encoder weights obtained in mode 0
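
As an illustration of how these fields relate to the two modes, the sketch below groups them in a small dataclass and checks that the weight paths are set before mode 1 is run. The class layout, the `validate` method, and all path values are assumptions made for this example; the real Config.py files may be organized differently.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical container mirroring the fields described above."""
    data_path: str
    save_path: str
    cls_weight_path: str = ""
    enc_weight_path: str = ""
    mode: int = 0

    def validate(self):
        """Return a list of human-readable problems (empty if none)."""
        problems = []
        if not self.data_path:
            problems.append("data_path is not set")
        if not self.save_path:
            problems.append("save_path is not set")
        if self.mode == 1:
            # Mode 1 relies on weights produced by a previous mode 0 run.
            if not self.cls_weight_path:
                problems.append("cls_weight_path is required in mode 1")
            if not self.enc_weight_path:
                problems.append("enc_weight_path is required in mode 1")
        return problems


# Placeholder paths; replace them with real locations.
pretrain_cfg = Config(data_path="/path/to/data", save_path="/path/to/results", mode=0)
generator_cfg = Config(data_path="/path/to/data", save_path="/path/to/results", mode=1)

pretrain_problems = pretrain_cfg.validate()
generator_problems = generator_cfg.validate()
```

Here `generator_problems` lists both weight paths as missing, which is the situation to avoid when switching from mode 0 to mode 1.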