- Authors
- Abstract
- How to explore this project
- Dataset
- Understanding Causal Variational AutoEncoder
- Sample Results
Farhanur Rahim Ansari, Gourang Patel, Sarang Pande, Vidhey Oza, Robert Ness
In this project we refactored the provided program for Causal Variational AutoEncoders so that there is a causal relationship between the latent variables of the dSprites dataset. Once retrained, we apply various conditionings and interventions to elements of the program and have the model generate a new image that reflects the applied intervention. We also apply counterfactual (twin-world) queries to our dataset and observe whether the CVAE reconstruction behaves appropriately.
We use the data from the dSprites repository. dSprites is a dataset of 2D shapes procedurally generated from 6 ground-truth independent latent factors: color, shape, scale, rotation, and the x and y positions of a sprite. All possible combinations of these latents are present exactly once, generating N = 737280 total images. The latent factor values are:
- Color: white
- Shape: square, ellipse, heart
- Scale: 6 values linearly spaced in [0.5, 1]
- Orientation: 40 values in [0, 2 pi]
- Position X: 32 values in [0, 1]
- Position Y: 32 values in [0, 1]
We varied one latent at a time (starting from Position Y, then Position X, etc.) and sequentially stored the images in a fixed order. Hence the order along the first dimension is fixed and allows you to map back to the values of the latents corresponding to each image. The latent values were chosen deliberately to have the smallest step changes while ensuring that all pixel outputs were different. No noise was added.
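Because the ordering is fixed, an image's row index can be computed directly from its latent class indices. Here is a minimal sketch in NumPy (it follows the mixed-radix "bases" trick from the official dSprites example notebook; the function name is ours):

```python
import numpy as np

# Number of values each latent can take: color, shape, scale, orientation, posX, posY
latents_sizes = np.array([1, 3, 6, 40, 32, 32])

# Mixed-radix bases: how far the flat index jumps when a latent increments by one
latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:], [1]))

def latent_to_index(latent_classes):
    """Map a vector of latent class indices to a row index into `imgs`."""
    return np.dot(latent_classes, latents_bases).astype(int)

# Example: white square, smallest scale, first orientation, centered position
idx = latent_to_index([0, 0, 0, 0, 16, 16])  # -> 528
```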
The data is an NPZ NumPy archive with the following fields:
- imgs: (737280 x 64 x 64, uint8) Images in black and white.
- latents_values: (737280 x 6, float64) Values of the latent factors.
- latents_classes: (737280 x 6, int64) Integer index of the latent factor values. Useful as classification targets.
- metadata: some additional information, including the possible latent values.
Alternatively, an HDF5 version is also available, containing the same data packed as Groups and Datasets.
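Loading the NPZ archive is straightforward with NumPy. A short sketch (the file name below is the one published in the dSprites repository; adjust the path to wherever you cloned it):

```python
import numpy as np

# allow_pickle/encoding are needed to unpickle the Python-2-era metadata dict
data = np.load("data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
               allow_pickle=True, encoding="latin1")

imgs = data["imgs"]                        # (737280, 64, 64) uint8 images
latents_values = data["latents_values"]    # (737280, 6) float64 latent values
latents_classes = data["latents_classes"]  # (737280, 6) int64 class indices
metadata = data["metadata"][()]            # dict with the possible latent values
```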
- Dimensionality reduction is the process of reducing the number of features that describe some data, either by selecting only a subset of the initial features or by combining them into a reduced number of new features. Hence it can also be seen as an encoding problem.
- Autoencoders are neural network architectures composed of an encoder and a decoder, trained to reconstruct the input during the encoding-decoding process. As a result, the encoder learns to reduce dimensionality without losing important information about the input; see the sketch after this list.
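For concreteness, a minimal dense autoencoder for 64x64 dSprites-sized images might look like this (a PyTorch sketch of the general encoder-decoder idea, not the notebook's actual architecture):

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Compress a 64x64 image into a small code, then reconstruct it."""

    def __init__(self, latent_dim=6):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 64, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),             # reduced representation
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 64 * 64), nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def forward(self, x):
        z = self.encoder(x)                          # encode: image -> code
        return self.decoder(z).view(-1, 1, 64, 64)   # decode: code -> image
```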
All the required dependencies are consolidated in requirements.txt
To install all the dependencies, run:
!pip install -r requirements.txt
This is the main Jupyter notebook that contains the full implementation of the Causal VAE with counterfactuals.
The first section mainly deals with the setup of the VAE as a supervised model. It loads the data from the dSprites repository; for error-free operation, ensure that you specify the correct path after cloning the repo into the data directory. The model is then trained and tested to verify that it trained correctly. An alternative to manual training is to run the Load weights cell.
The second section covers the construction of the Structural Causal Model (SCM). To make sure the model was built properly before performing causal operations, we run two sanity checks: generating a single image and reconstructing it via sampling, and checking whether the decoder can still generate the image when the latents are changed.
Then we move on to perform three causal operations: conditioning, intervention, and counterfactual reasoning.
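To make these operations concrete, here is a toy sketch using Pyro's effect handlers on a stand-in model (the variable names and distributions are illustrative; they are not the notebook's SCM):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.poutine import condition, do

def model():
    # Toy generative process over two dSprites-style latents
    shape = pyro.sample("shape", dist.Categorical(torch.ones(3) / 3))
    scale = pyro.sample("scale", dist.Categorical(torch.ones(6) / 6))
    return shape, scale

# Conditioning: restrict attention to worlds where shape was observed to be 2
conditioned_model = condition(model, data={"shape": torch.tensor(2)})

# Intervention: force shape to 2 regardless of its causes (Pearl's do-operator)
intervened_model = do(model, data={"shape": torch.tensor(2)})

# A counterfactual combines both: infer exogenous noise from an observation
# (abduction), apply a do-intervention (action), then replay the model (prediction)
print(conditioned_model(), intervened_model())
```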
To learn about the Causal Variational AutoEncoder step by step, we have also included tutorials (View code & View pdf).
The attached tutorials briefly explain the workings of Causal Variational AutoEncoders. They also provide step-wise solutions to various counterfactual queries applied to the Structural Causal Model.
Training was done on the Google Colab platform using a GPU. The dataset was divided into train and test splits in the data directory. Once the CVAE class functions are set up, we can execute the train and evaluate functions. The optimal learning rate we found is 1.0e-3, and the number of epochs is kept at 10. The optimizer used is Adam, as it works well with stochastic datasets such as ours. We observe from the ELBO plot that the training loss changes minimally after 10 epochs at this learning rate. We also compute the test loss every 5 epochs (TEST_EPOCH_FREQUENCY is set to 5) to make sure the model is neither overfitting nor underfitting.
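The hyperparameters above translate into a loop of roughly this shape (a runnable sketch with a stand-in linear model and random tensors; the notebook's CVAE, data loaders, and ELBO loss take their place):

```python
import torch
import torch.nn as nn

LEARNING_RATE = 1.0e-3
NUM_EPOCHS = 10
TEST_EPOCH_FREQUENCY = 5  # evaluate on held-out data every 5 epochs

model = nn.Linear(64 * 64, 6)   # stand-in for the CVAE
loss_fn = nn.MSELoss()          # stand-in for the ELBO loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_x, train_y = torch.rand(32, 64 * 64), torch.rand(32, 6)
test_x, test_y = torch.rand(8, 64 * 64), torch.rand(8, 6)

for epoch in range(1, NUM_EPOCHS + 1):
    optimizer.zero_grad()
    loss = loss_fn(model(train_x), train_y)
    loss.backward()
    optimizer.step()
    if epoch % TEST_EPOCH_FREQUENCY == 0:
        with torch.no_grad():  # periodic test loss guards against over/underfitting
            test_loss = loss_fn(model(test_x), test_y).item()
            print(f"epoch {epoch}: test loss {test_loss:.4f}")
```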
Once training is complete, we also save the trained model weights to ensure the reusability of our results. The observed results are sufficient to implement the interventions and conditioning: the average training loss after 10 epochs is 16.1449 and the average test loss after 5 epochs is 23.3984.
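Persisting the weights uses PyTorch's standard state-dict mechanism. A minimal sketch (the file name is an assumption, and the linear layer again stands in for the trained CVAE):

```python
import torch
import torch.nn as nn

model = nn.Linear(64 * 64, 6)  # stand-in for the trained CVAE

# Save the learned parameters once training finishes
torch.save(model.state_dict(), "cvae_weights.pt")

# On a later run, load them instead of retraining from scratch
model.load_state_dict(torch.load("cvae_weights.pt"))
model.eval()
```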
- The code is GPU-compatible for faster processing.
- The learned weights are saved to avoid frequent retraining, improving development efficiency.