Official repository for the CT-VAE model introduced in the article "Disentanglement of Latent Representations via Sparse Causal Interventions" [Gendron, Witbrock, and Dobbie; 2023]. This repository is based on a fork of the PyTorch-VAE library and adds two new models, the MCQ-VAE and the CT-VAE, as well as additional controls over the datasets and experiments.
The repository contains several datasets for reconstruction tasks:
The CelebA dataset can be downloaded here, while the other disentanglement datasets are obtained using the disent library. To install CelebA, please follow the recommendations from the original PyTorch-VAE repository.
Variants of these datasets for Causal Transition tasks are available:
- TCeleba
- TSprites
- TShapes3D
- TSmallNORB
- TDSprites
- TCars3D
The transition datasets contain pairs of images such that the value of only one label changes between the two images. If the label is categorical, then a transition can happen only between two adjacent values.
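The rule above can be illustrated with a small check on a pair of label vectors. This is a minimal sketch, not the code used by the repository: the function names `changed_labels` and `is_valid_transition` are hypothetical, and it assumes that categorical labels are encoded as integer indices, so that "adjacent" means the two indices differ by exactly one.

```python
from typing import Collection, List, Sequence


def changed_labels(src: Sequence[int], dst: Sequence[int]) -> List[int]:
    """Return the positions of the labels whose value differs between two images."""
    if len(src) != len(dst):
        raise ValueError("both images must carry the same number of labels")
    return [i for i, (a, b) in enumerate(zip(src, dst)) if a != b]


def is_valid_transition(
    src: Sequence[int],
    dst: Sequence[int],
    categorical: Collection[int] = (),
) -> bool:
    """Check that exactly one label changes between the two images.

    `categorical` holds the positions of the categorical labels. For those,
    the change must be between adjacent values (assumption: integer-encoded,
    so adjacent means a difference of exactly 1).
    """
    changed = changed_labels(src, dst)
    if len(changed) != 1:
        return False
    i = changed[0]
    if i in categorical:
        return abs(src[i] - dst[i]) == 1
    return True


# Label 1 is categorical: moving from 2 to 3 is allowed, from 2 to 4 is not.
print(is_valid_transition([0, 2, 5], [0, 3, 5], categorical={1}))
print(is_valid_transition([0, 2, 5], [0, 4, 5], categorical={1}))
```

A pair in which two labels change at once, or in which no label changes, is rejected regardless of the label types.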
To create the transition datasets, run the following scripts:
$ python utils/
$ python utils/ <dataset_name>
$ git clone
$ cd ct-vae
$ pip install -r requirements.txt
To run experiments, use the following lines of code:
$ cd PyTorch-VAE
$ python -c configs/<config-file-name.yaml>
name: "<name of VAE model>"
in_channels: <number of channels in the image, e.g. 3 for colour images and 1 for B&W>
. # Other parameters required by the model
data_path: "<path to the dataset storage, 'Data/' by default>"
dataset_name: "<name of the dataset>"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
limit: <restriction to the size of the dataset, available for TDatasets only>
distributed: <True if using several GPUs and False otherwise, this parameter is needed for TDatasets only>
manual_seed: 1265
LR: 0.005
find_unused_parameters: <True if model does not train all its parameters during a forward pass, False otherwise>
update_parameters: <subset of parameters to train, if specified, freezes the training of all other parameters of the model>
. # Other arguments required for training, like scheduler etc.
gpus: 1
max_epochs: 100
gradient_clip_val: 1.5
resume_from_checkpoint: "<optional, path to the model checkpoint to load the model from>"
load_weights_only: <use only if 'resume_from_checkpoint' is specified, if True, will not load the state of the optimizers>
save_dir: "logs/"
name: "<experiment name>"
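Several of the options above are only meaningful in combination with others (`load_weights_only` with `resume_from_checkpoint`, `limit` and `distributed` with the transition datasets). The sketch below shows one way to sanity-check these combinations before launching a run. It is an illustration only: `check_config` and `TRANSITION_DATASETS` are hypothetical names, the relevant options are assumed to have been gathered into a single flat dictionary, and the exact strings accepted by `dataset_name` are assumed to match the names listed in this README.

```python
from typing import Any, Dict, List

# Transition dataset names as listed in this README. Whether `dataset_name`
# expects exactly these strings is an assumption.
TRANSITION_DATASETS = {
    "TCeleba", "TSprites", "TShapes3D", "TSmallNORB", "TDSprites", "TCars3D",
}


def check_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of messages for option combinations the template advises against."""
    problems: List[str] = []
    is_transition = cfg.get("dataset_name") in TRANSITION_DATASETS

    # `load_weights_only` is documented as usable only with a checkpoint path.
    if "load_weights_only" in cfg and not cfg.get("resume_from_checkpoint"):
        problems.append("load_weights_only is set but resume_from_checkpoint is not")

    # `limit` is documented as available for the transition datasets only.
    if "limit" in cfg and not is_transition:
        problems.append("limit is only available for transition datasets")

    # `distributed` is documented as needed for the transition datasets.
    if is_transition and "distributed" not in cfg:
        problems.append("distributed must be set for transition datasets")

    # `distributed` should be True when several GPUs are used, False otherwise.
    gpus = cfg.get("gpus")
    if is_transition and isinstance(gpus, int) and "distributed" in cfg:
        if bool(cfg["distributed"]) != (gpus > 1):
            problems.append("distributed should be True only when several GPUs are used")

    return problems


example = {"dataset_name": "TShapes3D", "limit": 1000, "distributed": False, "gpus": 1}
print(check_config(example))
```

Returning a list of messages instead of raising on the first problem lets all inconsistencies be reported in one pass.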
Tensorboard logs can be accessed here:
$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir .
The experiments also store logs with wandb.
This repository supports hyperparameter search using Ray Tune:
$ cd PyTorch-VAE
$ python -c configs_hyp/<config-file-name.yaml>
