This is the official code implementation for the "Sparse Linear Concept Discovery Models" paper, accepted (Oral) @ ICCVW-CLVL 2023.
We propose a novel framework towards Interpretable Deep Networks using multi-modal models and a simple yet effective concept selection mechanism.
The recent mass adoption of DNNs, even in safety-critical scenarios, has shifted the focus of the research community towards the creation of inherently interpretable models. Concept Bottleneck Models (CBMs) constitute a popular approach where hidden layers are tied to human-understandable concepts, allowing for investigation and correction of the network's decisions. However, CBMs usually suffer from: (i) performance degradation and (ii) lower interpretability than intended, due to the sheer number of concepts contributing to each decision. In this work, we propose a simple yet highly intuitive interpretable framework based on Contrastive Language-Image models and a single sparse linear layer. In stark contrast to related approaches, the sparsity in our framework is achieved via principled Bayesian arguments by inferring concept presence via a data-driven Bernoulli distribution. As we experimentally show, our framework not only outperforms recent CBM approaches accuracy-wise, but it also yields high per-example concept sparsity, facilitating the individual investigation of the emerging concepts.
To make sure that everything works as intended, the file structure should be the following:
```
CDM
├── clip/
├── data/
├── saved_models/
├── main.py
├── data_utils.py
├── networks.py
├── utils.py
└── README.md
```
where the `saved_models` folder will be created automatically if it doesn't already exist when running the main script.
- Create a venv/conda environment containing all the necessary packages. This can be achieved using the provided `.yml` file.
- Specifically, run `conda env create -f clip_env.yml`.
When considering CUB and ImageNet, you should set the datasets up in the standard format and provide the correct paths in the corresponding entries of the `data_utils.py` file.
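For reference, a path entry in `data_utils.py` could look roughly like the following sketch; the variable name and structure here are illustrative assumptions, so adapt them to the actual entries in the file.

```python
# Illustrative sketch of dataset path entries in data_utils.py.
# The variable name and exact structure are assumptions; adapt them to
# the corresponding entries in the actual file.
DATASET_ROOTS = {
    "imagenet_train": "/path/to/imagenet/train",
    "imagenet_val": "/path/to/imagenet/val",
    "cub_train": "/path/to/CUB_200_2011/train",
    "cub_val": "/path/to/CUB_200_2011/val",
}
```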
As described in the main text, the models are trained using the embeddings arising from a pretrained CLIP model. To speed up training and inference, we first embed the dataset in the CLIP embedding space and then load the embedded vectors as the dataset to be used. For small datasets like CIFAR-10, CIFAR-100 and CUB, this is an easy task; producing and saving the embeddings for all images should only take a couple of minutes. ImageNet and Places365 take considerably longer, given their much larger number of images. For reproducibility and further development, we provide our image features in the following links:
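Independently of the precomputed features, the embedding step itself amounts to encoding every image with the frozen CLIP image encoder and saving the resulting features to disk. Below is a minimal, self-contained sketch of this idea (not the repository's actual pipeline), assuming the OpenAI `clip` package and CIFAR-10 via torchvision; the output filename is arbitrary.

```python
import clip
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Minimal sketch: embed CIFAR-10 with a frozen CLIP image encoder and save
# the features. The actual pipeline lives in main.py / data_utils.py.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

dataset = CIFAR10(root="./data", train=True, download=True, transform=preprocess)
loader = DataLoader(dataset, batch_size=256, num_workers=4)

features, labels = [], []
with torch.no_grad():
    for images, targets in loader:
        features.append(model.encode_image(images.to(device)).cpu())
        labels.append(targets)

torch.save({"features": torch.cat(features), "labels": torch.cat(labels)},
           "cifar10_train_clip_features.pt")
```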
For saving the text embeddings of a different dataset, one should use the following command:

```bash
python main.py --dataset cifar10 --compute_similarities
```

where you replace the `dataset` argument with the name of your dataset. For this to work, you need to implement a data loading function in the `data_utils.py` file.
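As a rough, hypothetical sketch (the function name, signature, and return value are assumptions; follow the pattern of the existing entries in `data_utils.py`), such a data loading function could look like this:

```python
# Hypothetical data loading helper for data_utils.py. Name, signature and
# return value are assumptions; mirror the existing dataset entries.
from torchvision import datasets

def get_my_dataset(split, preprocess):
    """Return the requested split of a custom dataset with CLIP preprocessing."""
    subdir = "train" if split == "train" else "val"
    return datasets.ImageFolder(
        root=f"/path/to/my_dataset/{subdir}",
        transform=preprocess,
    )
```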
The command above assumes that you use the default concept set, i.e., cifar100. To use a different concept set (even your own), specify its name in the `concept_name` argument, and make sure that your concept file is in the correct folder, i.e., `data/concept_sets/your_concept_set.txt`.
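Concept set files are plain text, typically with one concept per line (an assumption about the format; check the provided files in `data/concept_sets/` for the exact convention). Conceptually, the text side of the similarity computation reduces to encoding each concept with the CLIP text encoder, e.g.:

```python
import clip
import torch

# Sketch: encode a concept set with the CLIP text encoder. The file path
# follows the naming convention above; the one-concept-per-line format is
# an assumption.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)

with open("data/concept_sets/cifar100.txt") as f:
    concepts = [line.strip() for line in f if line.strip()]

tokens = clip.tokenize(concepts).to(device)
with torch.no_grad():
    text_features = model.encode_text(tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
```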
As an example, to use the ImageNet concept set, one could run the following command:

```bash
python main.py --dataset cifar100 --concept_name imagenet --compute_similarities
```
Assuming that you have already computed the embeddings, you can train the linear layer from scratch on a given dataset. To train the network on CIFAR-100 with the cifar100 concept set for 1000 epochs, the command is:

```bash
python main.py --dataset cifar100 --load_similarities --concept_name cifar100 --epochs 1000
```
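For intuition, the trained module is essentially a single linear classifier over image-concept similarities whose inputs are gated by a data-driven (relaxed) Bernoulli distribution. The sketch below illustrates that idea only; it is not the authors' exact implementation, which resides in `networks.py`, and the gate parameterisation and temperature value are assumptions.

```python
import torch
import torch.nn as nn

class GatedConceptClassifier(nn.Module):
    """Sketch of a sparse linear layer over concept similarities.

    Concept presence is modelled with a (relaxed) Bernoulli gate; the exact
    parameterisation used in the paper may differ (see networks.py).
    """

    def __init__(self, num_concepts, num_classes, temperature=0.1):
        super().__init__()
        self.classifier = nn.Linear(num_concepts, num_classes)
        self.gate = nn.Linear(num_concepts, num_concepts)  # data-driven gate logits
        self.temperature = temperature

    def forward(self, similarities):
        logits = self.gate(similarities)
        if self.training:
            # Relaxed Bernoulli (binary Concrete) sample keeps the gate differentiable.
            temp = torch.tensor(self.temperature, device=similarities.device)
            z = torch.distributions.RelaxedBernoulli(temp, logits=logits).rsample()
        else:
            # Hard concept selection at inference time yields sparse explanations.
            z = (torch.sigmoid(logits) > 0.5).float()
        return self.classifier(z * similarities)
```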
For the evaluation of a pretrained model, the command is:

```bash
python main.py --eval --dataset cifar100 --concept_name cifar100 --load_similarities --ckpt /path/to/your/checkpoint.pth.tar
```
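If you only want to inspect a saved checkpoint outside of `main.py`, it can be loaded directly with PyTorch; note that the dictionary keys below are assumptions about how the checkpoint is stored.

```python
import torch

# Sketch: inspect a saved checkpoint. The key names ("state_dict", etc.)
# are assumptions about the checkpoint layout produced by main.py.
ckpt = torch.load("/path/to/your/checkpoint.pth.tar", map_location="cpu")
print(ckpt.keys())
# model.load_state_dict(ckpt["state_dict"])  # once the model is instantiated
```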
Accuracy per dataset (for the "w/ Z" models, the value in parentheses denotes the achieved per-example concept sparsity):

| Model | CIFAR10 | CIFAR100 | CUB200 | Places365 | ImageNet |
|---|---|---|---|---|---|
| Standard | 88.80% | 70.10% | 76.70% | 48.56% | 76.13% |
| Standard (sparse) | 82.96% | 58.34% | 75.96% | 38.46% | 74.35% |
| Label-free CBM | 86.37% | 65.27% | 74.59% | 43.71% | 71.98% |
| CDM (RN50, w/o Z) | 81.90% | 63.40% | 64.70% | 52.90% | 71.20% |
| CDM (RN50, w/ Z) | 86.50% (2.55) | 67.60% (9.30) | 72.26% (21.3) | 52.70% (8.28) | 72.20% (8.53) |
| CDM (ViT-B/16, w/o Z) | 94.45% | 79.00% | 75.10% | 54.40% | 77.90% |
| CDM (ViT-B/16, w/ Z) | 95.30% (1.69) | 80.50% (3.38) | 79.50% (13.4) | 52.58% (8.00) | 79.30% (6.96) |