Official implementation of the paper Schema Inference for Interpretable Image Classification (ICLR 2023).
Authors: Haofei Zhang, Mengqi Xue, Xiaokang Liu, Kaixuan Chen, Jie Song, Mingli Song
- CIFAR: download the CIFAR-10/100 datasets to `~/datasets/cifar` (you may change this path in the configuration files).
- Caltech-101: download Caltech-101 to `~/datasets/caltech-101` (you need to split it manually into training and test sets).
- ImageNet: download the ImageNet dataset to `~/datasets/ILSVRC2012` and pre-process it with this script.

For more details, please refer to the corresponding dataset implementation at this url.
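Because Caltech-101 must be split by hand, here is a minimal sketch of one way to do it. It assumes the extracted `101_ObjectCategories/<class>/*.jpg` layout and writes `train/<class>/` and `test/<class>/` folders. The `split_caltech101` name, the destination layout and the 30-images-per-class training quota are hypothetical, so adjust them to whatever the Caltech-101 dataset implementation actually expects.

```python
import random
import shutil
from pathlib import Path


def split_caltech101(src_root, dst_root, train_per_class=30, seed=0,
                     exclude=("BACKGROUND_Google",)):
    """Copy images from <src_root>/<class>/ into <dst_root>/train/<class>/ and
    <dst_root>/test/<class>/.

    Hypothetical protocol: `train_per_class` randomly chosen images per class
    go to train and the remainder go to test. Classes named in `exclude`
    (by default the BACKGROUND_Google clutter folder) are skipped.
    Returns {class_name: (num_train, num_test)}.
    """
    src_root, dst_root = Path(src_root), Path(dst_root)
    rng = random.Random(seed)  # fixed seed -> reproducible split
    counts = {}
    class_dirs = sorted(p for p in src_root.iterdir()
                        if p.is_dir() and p.name not in exclude)
    for class_dir in class_dirs:
        images = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(images)
        splits = {"train": images[:train_per_class],
                  "test": images[train_per_class:]}
        for split, files in splits.items():
            out_dir = dst_root / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in files:
                shutil.copy2(f, out_dir / f.name)
        counts[class_dir.name] = (len(splits["train"]), len(splits["test"]))
    return counts
```

Copying rather than moving keeps the original download intact, so the split can be regenerated with a different seed or quota.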
Our code requires cv-lib-PyTorch. Please download that repo and check out the tag `schema_inference`. cv-lib-PyTorch is an open-source repo currently maintained by me.
torch==1.12.1+cu113
torchvision==0.13.1+cu113
tqdm
tensorboard
scipy
PyYAML
pandas
numpy
graphviz
h5py
matplotlib
networkx
scikit-learn
seaborn
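Because the list above pins specific CUDA 11.3 builds of `torch` and `torchvision`, it can help to check the active environment before running anything. The sketch below parses requirement lines like those above and compares them against installed distributions via `importlib.metadata`. The helper names are hypothetical, and unpinned packages are only checked for presence.

```python
from importlib import metadata


def parse_requirements(lines):
    """Turn requirement lines into {name: pinned_version_or_None}."""
    reqs = {}
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        name, sep, version = line.partition("==")
        reqs[name.strip()] = version.strip() if sep else None
    return reqs


def check_environment(reqs, get_version=metadata.version):
    """Return a list of problems: missing packages or version mismatches."""
    problems = []
    for name, pinned in reqs.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if pinned is not None and installed != pinned:
            problems.append(f"{name}: found {installed}, expected {pinned}")
    return problems
```

For example, assuming the list above is saved as `requirements.txt`, `check_environment(parse_requirements(Path("requirements.txt").read_text().splitlines()))` (with `from pathlib import Path`) returns an empty list when everything matches.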
| Backbone | URL |
|---|---|
| DeiT-Tiny | ckpt |
| DeiT-Small | ckpt |
| DeiT-Base | ckpt |
These pre-trained weights are converted from the official DeiT repo by renaming submodules to match our code. Please download the required weights to `CODE_DIR/weights/`.
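The checkpoints listed above are already converted, so nothing needs to be done by hand. Purely to illustrate what renaming submodules in a checkpoint involves, here is a minimal sketch that rewrites key prefixes in a state dict. The `rename_submodules` helper and the prefix mapping used in the example are hypothetical and are not the actual mapping used to produce these weights.

```python
from collections import OrderedDict


def rename_submodules(state_dict, prefix_map):
    """Return a copy of `state_dict` with key prefixes rewritten via `prefix_map`.

    For each key, the first matching prefix (in `prefix_map` order) is replaced;
    keys with no matching prefix are kept unchanged. Values are not copied.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        for old, new in prefix_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed
```

With PyTorch, one would load a checkpoint with `torch.load(path, map_location="cpu")`, pass its state dict through this function and write the result back with `torch.save`. Note that the official DeiT checkpoints store their weights under a `"model"` key.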
sh 0.build.sh
sh 0.train_backbone.sh
sh 1.extract_ingredients.sh
sh 2.save_backbone_jit.sh
sh 3.init_schema_net.sh
sh 4.train_schema_net.sh
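The scripts above appear to be numbered in execution order. Assuming they are run one after another from `CODE_DIR`, a minimal Python driver that stops at the first failing stage could look like this. `run_pipeline`, `STAGES` and the `start` parameter are illustrative names, not part of the repo.

```python
import subprocess
from pathlib import Path

# Stage scripts as listed in this README, assumed to run in this order.
STAGES = [
    "0.build.sh",
    "0.train_backbone.sh",
    "1.extract_ingredients.sh",
    "2.save_backbone_jit.sh",
    "3.init_schema_net.sh",
    "4.train_schema_net.sh",
]


def run_pipeline(code_dir=".", stages=STAGES, start=0, run=subprocess.run):
    """Run each stage with `sh`, in order, from `code_dir`.

    `check=True` makes a failing stage raise CalledProcessError, so later
    stages never run on incomplete outputs. `start` resumes from a later
    stage. Returns the list of stages that completed.
    """
    completed = []
    for script in stages[start:]:
        run(["sh", script], cwd=Path(code_dir), check=True)
        completed.append(script)
    return completed
```

For example, `run_pipeline(start=2)` resumes at `1.extract_ingredients.sh`, e.g. after the backbone has already been trained. Injecting `run` keeps the driver testable without launching real jobs.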
If you find this work useful for your research, please cite our paper:
@inproceedings{
zhang2023schema,
title={Schema Inference for Interpretable Image Classification},
author={Haofei Zhang and Mengqi Xue and Xiaokang Liu and Kaixuan Chen and Jie Song and Mingli Song},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=VGI9dSmTgPF}
}