Official repository for the CT-VAE model introduced in the paper "Disentanglement of Latent Representations via Sparse Causal Interventions" [Gendron, Witbrock, and Dobbie; 2023]. The repository is a fork of the PyTorch-VAE library that adds two new models, the MCQ-VAE and the CT-VAE, along with additional controls over the datasets and experiments.
The repository contains several datasets for reconstruction tasks:
- CelebA
- Sprites
- Shapes3D
- SmallNORB
- DSprites
- Cars3D
The CelebA dataset can be downloaded from its official website, while the other disentanglement datasets are obtained using the disent library. To install CelebA, please follow the recommendations from the original PyTorch-VAE repository.
Variants of these datasets for Causal Transition tasks are available:
- TCeleba
- TSprites
- TShapes3D
- TSmallNORB
- TDSprites
- TCars3D
The transition datasets contain pairs of images in which the value of exactly one label changes between the two images. If the label is categorical, a transition can only happen between two adjacent values.
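As an illustration, the sketch below (Python; not code from this repository) shows the constraint a valid transition pair must satisfy, assuming factor labels are encoded as integers:

# Minimal sketch of the transition constraint, assuming integer-coded labels.
def is_valid_transition(labels_a, labels_b):
    # Collect the factors whose value differs between the two images
    diffs = [(a, b) for a, b in zip(labels_a, labels_b) if a != b]
    if len(diffs) != 1:      # exactly one label may change
        return False
    a, b = diffs[0]
    return abs(a - b) == 1   # and only to an adjacent value

assert is_valid_transition([2, 0, 5], [2, 1, 5])      # one label, adjacent step
assert not is_valid_transition([2, 0, 5], [2, 2, 5])  # step of size 2
assert not is_valid_transition([2, 0, 5], [3, 1, 5])  # two labels changed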
To create the transition datasets, run the following scripts:
$ python utils/celeba_variation_gen.py
$ python utils/disent_variation_gen.py <dataset_name>
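For example, to generate the transition variant of Shapes3D (the identifier 'shapes3d' is an assumption here; check the script for the exact accepted names):

$ python utils/disent_variation_gen.py shapes3d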
To install the repository and its dependencies, run:
$ git clone https://github.com/Strong-AI-Lab/ct-vae
$ cd ct-vae
$ pip install -r requirements.txt
To run experiments, use the following commands:
$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>
The config files have the following structure:

model_params:
  name: "<name of the VAE model>"
  in_channels: <number of channels in the image, e.g. 3 for colour images and 1 for B&W>
  .   # Other parameters required by the model
  .
  .

data_params:
  data_path: "<path to the dataset storage, 'Data/' by default>"
  dataset_name: "<name of the dataset>"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4
  limit: <restriction on the size of the dataset, available for TDatasets only>
  distributed: <True if using several GPUs, False otherwise; needed for TDatasets only>

exp_params:
  manual_seed: 1265
  LR: 0.005
  find_unused_parameters: <True if the model does not train all of its parameters during a forward pass, False otherwise>
  update_parameters: <subset of parameters to train; if specified, training of all other parameters of the model is frozen>
  .   # Other arguments required for training, e.g. the scheduler
  .
  .

trainer_params:
  gpus: 1
  max_epochs: 100
  gradient_clip_val: 1.5
  resume_from_checkpoint: "<optional, path to the model checkpoint to load the model from>"
  load_weights_only: <use only if 'resume_from_checkpoint' is specified; if True, the optimizer states are not loaded>
  .
  .
  .

logging_params:
  save_dir: "logs/"
  name: "<experiment name>"
TensorBoard logs can be accessed as follows:
$ cd logs/<experiment name>/version_<the version you want>
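From that directory, the logs can be visualised with the standard TensorBoard command:

$ tensorboard --logdir .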
The experiments also record their logs with wandb.
The repository also supports hyperparameter search using Ray Tune:
$ cd PyTorch-VAE
$ python hyperparameter_search.py -c configs_hyp/<config-file-name.yaml>
If you use this repository, please cite our work:
@article{DBLP:journals/corr/abs-2302-00869,
  author     = {Ga{\"{e}}l Gendron and
                Michael Witbrock and
                Gillian Dobbie},
  title      = {Disentanglement of Latent Representations via Sparse Causal Interventions},
  journal    = {CoRR},
  volume     = {abs/2302.00869},
  year       = {2023},
  url        = {https://doi.org/10.48550/arXiv.2302.00869},
  doi        = {10.48550/arXiv.2302.00869},
  eprinttype = {arXiv},
  eprint     = {2302.00869},
  timestamp  = {Thu, 09 Feb 2023 16:11:17 +0100},
  biburl     = {https://dblp.org/rec/journals/corr/abs-2302-00869.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}