To use this repository, you should be familiar with PyTorch and PyTorch Lightning. We also recommend using one of the logging frameworks supported by PyTorch Lightning, for example Weights & Biases or Neptune.
This repository is the official implementation of Towards More Robust Interpretation via Local Gradient Alignment.
- Supported training losses: CE loss, Hessian regularizer, and l2+cosd regularizer
- Supported datasets: CIFAR10 and ImageNet100
- Supported models: LeNet and ResNet18
We highly recommend using our conda environment.
# clone project
git clone https://github.com/joshua840/RobustAttributionGradAlignment.git RobustAGA
# install project
cd RobustAGA
conda env create -f environment.yml
conda activate agpu_env
Our directory structure looks like this:
├── project
│ ├── module <- All modules are in this directory
│ │ ├── lrp_module <- Modules to get LRP XAI are in this directory
│ │ ├── models <- Models
│ │ ├── utils <- Utilities
│ │ ├── pl_classifier.py <- basic classifier
│ │ ├── pl_hessian_classifier.py <- hessian regularization
│ │ ├── pl_l2_plus_cosd_classifier.py <- l2 + cosd regularization
│ │ ├── test_adv_insertion.py <- run Adv-Insertion test
│ │ ├── test_insertion.py <- run Insertion test
│ │ ├── test_rps.py <- run Random Perturbation Similarity test
│ │ ├── test_taps_saps.py <- run adversarial attack test
│ │ └── test_upper_bound.py <- run upper bound test
│ │
│ ├── main.py <- run train & test
│ └── test_main.py <- run advanced test codes
│
├── scripts <- Shell scripts
│
├── .gitignore <- List of files/folders ignored by git
├── environment.yml <- anaconda environment
└── README.md
You can list the available arguments by passing -h on the command line.
python project/main.py -h
usage: main.py [-h] [--seed SEED] [--regularizer REGULARIZER] [--loggername LOGGERNAME] [--project PROJECT] [--dataset DATASET] [--model MODEL]
[--activation_fn ACTIVATION_FN] [--softplus_beta SOFTPLUS_BETA] [--optimizer OPTIMIZER] [--weight_decay WEIGHT_DECAY]
[--learning_rate LEARNING_RATE] [--milestones MILESTONES [MILESTONES ...]] [--num_workers NUM_WORKERS] [--batch_size_train BATCH_SIZE_TRAIN]
[--batch_size_test BATCH_SIZE_TEST] [--data_dir DATA_DIR]
optional arguments:
-h, --help show this help message and exit
--seed SEED random seeds (default: 1234)
--regularizer REGULARIZER
A regularizer to be used (default: none)
--loggername LOGGERNAME
a name of logger to be used (default: default)
--project PROJECT a name of project to be used (default: default)
--dataset DATASET dataset to be loaded (default: cifar10)
Default classifier:
--model MODEL which model to be used (default: none)
--activation_fn ACTIVATION_FN
activation function of model (default: relu)
--softplus_beta SOFTPLUS_BETA
beta of softplus (default: 20.0)
--optimizer OPTIMIZER
optimizer name (default: adam)
--weight_decay WEIGHT_DECAY
weight decay for optimizer (default: 4e-05)
--learning_rate LEARNING_RATE
learning rate for optimizer (default: 0.001)
--milestones MILESTONES [MILESTONES ...]
lr scheduler (default: [100, 150])
Data arguments:
--num_workers NUM_WORKERS
number of workers (default: 4)
--batch_size_train BATCH_SIZE_TRAIN
batchsize of data loaders (default: 128)
--batch_size_test BATCH_SIZE_TEST
batchsize of data loaders (default: 100)
--data_dir DATA_DIR directory of cifar10 dataset (default: /data/cifar10)
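The --activation_fn and --softplus_beta options suggest that ReLU can be swapped for a smooth softplus activation. The following is a minimal sketch in plain Python, assuming the standard softplus definition log(1 + exp(beta * x)) / beta (the formula PyTorch uses for its Softplus layer); it only illustrates why the default beta of 20 stays close to ReLU, and is not taken from this repository's code.

```python
import math


def softplus(x, beta=20.0):
    """Softplus with sharpness beta: log(1 + exp(beta * x)) / beta."""
    z = beta * x
    # For large z, log(1 + exp(z)) is numerically equal to z, so skip exp to avoid overflow.
    if z > 30.0:
        return x
    return math.log1p(math.exp(z)) / beta


def relu(x):
    """Plain ReLU for comparison."""
    return max(0.0, x)


# Compare ReLU with a sharp (beta=20) and a soft (beta=1) softplus.
for x in (-1.0, -0.1, 0.0, 0.1, 1.0):
    print(
        f"x={x:+.1f}  relu={relu(x):.4f}  "
        f"softplus(beta=20)={softplus(x):.4f}  softplus(beta=1)={softplus(x, 1.0):.4f}"
    )
```

A larger beta makes the curve sharper around zero, while the function remains differentiable everywhere.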
In our code, the trainer module is selected according to the chosen regularizer. Each Lightning module class defines an add_model_specific_args function, which registers the additional arguments used by that class. By passing the --regularizer option together with -h on the command line, you can also see these additional arguments.
python project/main.py --regularizer l2_cosd -h
l2_cosd arguments:
--eps EPS
--lamb LAMB
--lamb_c LAMB_C
--detach_source_grad DETACH_SOURCE_GRAD
python project/main.py --regularizer hessian -h
Hessian arguments:
--lamb LAMB
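The pattern described above (each module class contributing its own arguments through add_model_specific_args) can be sketched with argparse alone. The class names, the MODULES table, and all default values below are hypothetical placeholders chosen for illustration; only the option names mirror the help output shown above.

```python
import argparse


class HessianModuleSketch:
    """Hypothetical stand-in for a module class with its own CLI options."""

    @staticmethod
    def add_model_specific_args(parent_parser):
        group = parent_parser.add_argument_group("Hessian arguments")
        group.add_argument("--lamb", type=float, default=1.0)  # illustrative default
        return parent_parser


class L2CosdModuleSketch:
    """Hypothetical stand-in for the l2 + cosd module class."""

    @staticmethod
    def add_model_specific_args(parent_parser):
        group = parent_parser.add_argument_group("l2_cosd arguments")
        group.add_argument("--eps", type=float, default=0.0)     # illustrative default
        group.add_argument("--lamb", type=float, default=1.0)    # illustrative default
        group.add_argument("--lamb_c", type=float, default=1.0)  # illustrative default
        return parent_parser


# Hypothetical lookup table from regularizer name to module class.
MODULES = {"hessian": HessianModuleSketch, "l2_cosd": L2CosdModuleSketch}


def build_parser(argv):
    # First pass: read only --regularizer and ignore everything else, including -h.
    base = argparse.ArgumentParser(add_help=False)
    base.add_argument("--regularizer", type=str, default="none")
    known, _ = base.parse_known_args(argv)

    # Second pass: build the full parser and let the selected class extend it.
    parser = argparse.ArgumentParser(parents=[base])
    module = MODULES.get(known.regularizer)
    if module is not None:
        parser = module.add_model_specific_args(parser)
    return parser


argv = ["--regularizer", "l2_cosd", "--lamb", "0.5"]
args = build_parser(argv).parse_args(argv)
print(args.regularizer, args.lamb, args.lamb_c)
```

Parsing --regularizer in a first pass is what lets -h show a different argument group depending on the selected regularizer.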
4.3 Hidden arguments
PyTorch Lightning offers a useful list of arguments for training. For example, we used --max_epochs and --default_root_dir in our experiments. We strongly recommend referring to the PyTorch Lightning Trainer documentation to check the full argument list.
We offer three logger options.
- Tensorboard (https://www.tensorflow.org/tensorboard)
  - Logs and model checkpoints are saved in --default_root_dir
  - Logging test code with Tensorboard is not available.
- Weights & Biases (https://wandb.ai/site)
  - Create a new project on the WandB website.
  - Specify the project argument --project
- Neptune AI (https://neptune.ai/)
  - Create a new project on the Neptune website.
  - export NEPTUNE_API_TOKEN="YOUR API TOKEN"
  - export NEPTUNE_ID="YOUR ID"
  - Set --default_root_dir as output/YOUR_PROJECT_NAME
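Before launching a run with the Neptune logger, it can help to confirm that the two environment variables listed above are set. The helper below is a small sketch that is not part of this repository; the function names are hypothetical, and only the variable names and the output/YOUR_PROJECT_NAME convention come from the list above.

```python
import os

# Environment variables named in the Neptune setup steps.
REQUIRED_NEPTUNE_VARS = ("NEPTUNE_API_TOKEN", "NEPTUNE_ID")


def missing_neptune_vars(environ=None):
    """Return the required variable names that are unset or empty."""
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_NEPTUNE_VARS if not env.get(name)]


def default_root_dir(project_name):
    """Build the --default_root_dir value following the output/YOUR_PROJECT_NAME convention."""
    return f"output/{project_name}"


if __name__ == "__main__":
    missing = missing_neptune_vars()
    if missing:
        print("Please export:", ", ".join(missing))
    else:
        print("Neptune variables are set; use --default_root_dir", default_root_dir("YOUR_PROJECT_NAME"))
```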
Likewise, you can check the options for the test code.
python project/test_main.py --test_method aopc -h
python project/test_main.py --test_method adv -h
python project/test_main.py --test_method adv_aopc -h
python project/test_main.py --test_method rps -h
python project/test_main.py --test_method upper_bound -h
For the test commands above, you must specify the --exp_id argument. You can find the experiment ID on your project's web page; it looks like EXP1-1 for Neptune and 1skdq34 for WandB. These runs append additional logs to your project.
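To run every test method against one experiment, the commands can be generated programmatically. This is a minimal sketch: the build_test_commands helper is hypothetical, and it uses only the --test_method values and the --exp_id flag shown above. Your setup may need further arguments (for example logger or project options), which can be passed through extra_args.

```python
import shlex

# Test methods listed in the commands above.
TEST_METHODS = ("aopc", "adv", "adv_aopc", "rps", "upper_bound")


def build_test_commands(exp_id, extra_args=()):
    """Return one argument list per test method for the given experiment ID."""
    commands = []
    for method in TEST_METHODS:
        cmd = [
            "python", "project/test_main.py",
            "--test_method", method,
            "--exp_id", exp_id,
            *extra_args,
        ]
        commands.append(cmd)
    return commands


# Print the commands; pass each list to subprocess.run(cmd, check=True) to execute them.
for cmd in build_test_commands("EXP1-1"):
    print(shlex.join(cmd))
```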
This project is set up as a package, which means you can easily import any file into any other file like so:
from pytorch_lightning import Trainer
from project.module.pl_classifier import LitClassifier
from project.module.utils.data_module import CIFAR10DataModule
from project.module.utils.data_module import ImageNet100DataModule
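# Data (choose one data module)
data_module = CIFAR10DataModule()
# data_module = ImageNet100DataModule()
# Model (model_name, activation_fn, and beta are placeholders for your own settings)
model = LitClassifier(model=model_name, activation_fn=activation_fn, softplus_beta=beta).cuda()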
# train
trainer = Trainer()
trainer.fit(model, data_module)
# test using the best model!
trainer.test(model, data_module)
You can also load a trained checkpoint and compute attribution heatmaps with the interpreter like so:
import torch
from project.module.test_upper_bound import LitClassifierUpperBoundTester as LitClassifier
from project.module.utils.data_module import CIFAR10DataModule, ImageNet100DataModule
from project.module.utils.interpreter import Interpreter
# Data
data_module = CIFAR10DataModule(dataset='cifar10', batch_size_test=10, data_dir='../data/cifar10')
data_module.prepare_data()
data_module.setup()
test_loader = data_module.test_dataloader()
x_batch, y_batch = next(iter(test_loader))
x_s = x_batch.cuda().requires_grad_()
y_s = y_batch.cuda()
# Model
ckpt_path = 'YOUR_CHECKPOINT_PATH'
ckpt = torch.load(ckpt_path)
args = ckpt['hyper_parameters']
model = LitClassifier(**args).cuda()
model.load_state_dict(ckpt['state_dict'])
model.eval()
# Use interpreter
yhat_s = model(x_s)
h_s = Interpreter(model).get_heatmap(x_s, y_s, yhat_s, "grad", 'standard', 'abs', False).detach()
@article{joo2022towards,
title={Towards More Robust Interpretation via Local Gradient Alignment},
author={Joo, Sunghwan and Jeong, Seokhyeon and Heo, Juyeon and Weller, Adrian and Moon, Taesup},
journal={arXiv preprint arXiv:2211.15900},
year={2022}
}
The citation for the AAAI version will be updated later.