PyTorch Implementation of No Token Left Behind: Explainability-Aided Image Classification and Generation
First, follow DATASETS.md to install the datasets. Then create the required environment with:
conda env create -f external/CoOp/dassl_env.yml
conda activate dassl
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
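Optionally, the pinned versions can be verified after installation. The following is a minimal sketch, not part of the repository: the helper names `check_versions` and `installed_versions` are hypothetical, and it assumes the installed wheels report their version through `__version__` in the same form as the pins above.

```python
import importlib

# Versions pinned by the pip command above.
PINNED = {"torch": "1.7.1+cu110", "torchvision": "0.8.2+cu110"}


def check_versions(installed, pinned=PINNED):
    """Return a list of readable mismatches; an empty list means every pin matches."""
    problems = []
    for name, want in pinned.items():
        have = installed.get(name)
        if have != want:
            problems.append(f"{name}: expected {want}, found {have}")
    return problems


def installed_versions(names=PINNED):
    """Read __version__ from each package; None if it cannot be imported."""
    found = {}
    for name in names:
        try:
            found[name] = importlib.import_module(name).__version__
        except Exception:
            # Missing package or a broken install (e.g. mismatched binaries).
            found[name] = None
    return found


if __name__ == "__main__":
    # Run inside the activated 'dassl' environment.
    for line in check_versions(installed_versions()) or ["all pinned versions match"]:
        print(line)
```

An empty result from `check_versions` indicates that both packages match the pins; any other output lists the packages to reinstall.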
Then clone and install Dassl under the 'external' directory:
cd external/Dassl.pytorch/
python setup.py develop
cd ../../
To run the experiment, replace the placeholders in angle brackets with your own values and run:
python external/CoOp/train.py \
  --root <dataset_root> \
  --trainer CoOp \
  --dataset-config-file <dataset config file> \
  --config-file external/CoOp/configs/trainers/CoOp/<base model>_ep50.yaml \
  --output-dir <output_dir> \
  --model-dir <model_dir> \
  --seed 1 \
  DATASET.NUM_SHOTS 1 \
  TRAINER.COOP.EXPL_WEIGHT <expl_lambda> \
  TRAINER.COOP.CSC False \
  TRAINER.COOP.RETURN_EXPL_SCORE True \
  TRAINER.COOP.CLASS_TOKEN_POSITION middle \
  TRAINER.COOP.N_CTX 16
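To try several values of the explainability weight, the same command can be generated programmatically. The sketch below is an illustration rather than part of the repository: the helpers `build_coop_command` and `sweep_commands`, the per-weight output directory layout, and the example weights 0.5 and 1.0 are all assumptions. The flags and config keys are copied from the command above, and the angle-bracket placeholders are left for you to fill in.

```python
import shlex


def build_coop_command(root, dataset_config, base_model, output_dir, model_dir,
                       expl_weight, seed=1, num_shots=1):
    """Assemble the argument list for external/CoOp/train.py, mirroring the command above."""
    return [
        "python", "external/CoOp/train.py",
        "--root", root,
        "--trainer", "CoOp",
        "--dataset-config-file", dataset_config,
        "--config-file", f"external/CoOp/configs/trainers/CoOp/{base_model}_ep50.yaml",
        "--output-dir", output_dir,
        "--model-dir", model_dir,
        "--seed", str(seed),
        "DATASET.NUM_SHOTS", str(num_shots),
        "TRAINER.COOP.EXPL_WEIGHT", str(expl_weight),
        "TRAINER.COOP.CSC", "False",
        "TRAINER.COOP.RETURN_EXPL_SCORE", "True",
        "TRAINER.COOP.CLASS_TOKEN_POSITION", "middle",
        "TRAINER.COOP.N_CTX", "16",
    ]


def sweep_commands(expl_weights, output_root, **common):
    """One command per weight; each run writes to its own subdirectory (assumed layout)."""
    return [
        build_coop_command(expl_weight=w, output_dir=f"{output_root}/expl_{w}", **common)
        for w in expl_weights
    ]


# Placeholder values: replace them before running anything.
commands = sweep_commands(
    expl_weights=[0.5, 1.0],            # arbitrary illustration values
    output_root="<output_dir>",
    root="<dataset_root>",
    dataset_config="<dataset config file>",
    base_model="<base model>",
    model_dir="<model_dir>",
)
for cmd in commands:
    # shlex.quote keeps arguments with spaces or brackets shell-safe.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

Each printed line can be pasted into a shell, or the argument lists can be passed directly to `subprocess.run` once the placeholders are replaced.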
@misc{Paiss2022NoTL,
url = {https://arxiv.org/abs/2204.04908},
author = {Paiss, Roni and Chefer, Hila and Wolf, Lior},
title = {No Token Left Behind: Explainability-Aided Image Classification and Generation},
publisher = {arXiv},
year = {2022}
}
- Image manipulation code is based on StyleCLIP
- Image generation code is based on FuseDream
- Image generation with spatial conditioning code is based on VQGAN+CLIP and VQGAN
- Prompt engineering code is based on CoOp and Dassl
- Explainability method code is based on Transformer-MM-Explainability
This sample code is released under the LICENSE terms.