
Re-implementation of Variational Neural Cellular Automata

This repository contains code to reproduce the results of "Variational Neural Cellular Automata" [1].

The code uses the autograd engine JAX, the neural network library equinox, the optimization library optax, and the tensor operation library einops.

The reproduction focuses on the paper's main results on the binarized MNIST dataset [2].


Requirements

To install the requirements locally, run:

pip install -r requirements.txt

Training

To train a model on the TPU v3-8 accelerator (8 TPU cores) available on Kaggle, import the script main-train.py as a Kaggle notebook via:

  • Create -> New Notebook -> File -> Import Notebook


Then select the TPU accelerator.

The script can now be run as a notebook.

Evaluation

To evaluate a trained model, use the script eval.py, imported into Kaggle in the same way as the training script.

Results

Our models achieve the following performance on binarized MNIST:

| Model | IWELBO on the test set (128 importance-weighted samples) |
|---|---|
| BaselineVAE | -84.64 nats |
| DoublingVNCA | -84.15 nats |
| NonDoublingVNCA | -89.3 nats |
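The IWELBO with K importance samples is log((1/K) Σ_k p(x, z_k) / q(z_k | x)), with each z_k drawn from the encoder q(z | x); here K = 128. The sketch below is not taken from the repository: it only illustrates how such a bound is commonly computed from per-sample log weights using a numerically stable log-sum-exp. The helper name `iwelbo` and its input format are hypothetical.

```python
import math
from typing import Sequence


def iwelbo(log_weights: Sequence[float]) -> float:
    """Importance-weighted ELBO for a single data point (hypothetical helper).

    Each entry is log w_k = log p(x, z_k) - log q(z_k | x) for a sample
    z_k ~ q(z | x). Returns log((1/K) * sum_k w_k), computed with the
    log-sum-exp trick so that very negative log weights do not underflow.
    """
    k = len(log_weights)
    if k == 0:
        raise ValueError("need at least one importance sample")
    m = max(log_weights)  # shift by the max so the largest term becomes exp(0) = 1
    lse = m + math.log(sum(math.exp(lw - m) for lw in log_weights))
    return lse - math.log(k)  # subtract log K to average the weights instead of summing


# By Jensen's inequality the bound is never below the mean of the log weights
# (the plain ELBO estimate); with K = 1 the two coincide.
samples = [-85.0, -84.0, -90.0, -83.5]
print(iwelbo(samples))
print(sum(samples) / len(samples))
```

Assuming the table reports the mean over test images, each number would be this per-image quantity averaged over the whole test set, in nats.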

Figures

The figures from the paper can be reproduced with the following scripts:

| Figure | Script to reproduce the figure |
|---|---|
| Figure 2 | sample.py |
| Figure 3 | sample.py |
| Figure 4 | latent_interpolate.py and t-sne.py |
| Figure 5 | damage_recovery.py |
| Figure 6 | linear_probe_figure.py |
| Figure 7 | latent_viz.py |
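Judging by its name, latent_interpolate.py interpolates between latent codes for Figure 4. Its exact scheme is not documented here, so the sketch below assumes plain linear interpolation between two latent vectors z_a and z_b, where each intermediate point would then be decoded into an image. The function `interpolate` is hypothetical and not part of the repository.

```python
from typing import List, Sequence


def interpolate(z_a: Sequence[float], z_b: Sequence[float], steps: int) -> List[List[float]]:
    """Return `steps` evenly spaced points on the line from z_a to z_b, endpoints included."""
    if len(z_a) != len(z_b):
        raise ValueError("latent vectors must have the same dimension")
    if steps < 2:
        raise ValueError("need at least the two endpoints")
    path = []
    for i in range(steps):
        t = i / (steps - 1)  # t runs from 0.0 to 1.0
        path.append([(1.0 - t) * a + t * b for a, b in zip(z_a, z_b)])
    return path


# Each point would be passed to the decoder to render one frame of the interpolation.
for z in interpolate([0.0, 1.0], [1.0, -1.0], steps=5):
    print(z)
```

Linear interpolation is the simplest choice; with a Gaussian prior, spherical interpolation is sometimes preferred instead, which this sketch does not implement.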

Tests

To run the tests, which check the model output shape and the doubling operation, run:

pytest tests.py
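In the VNCA decoder described in [1], the cell grid is periodically doubled: each cell is copied into a 2x2 block, so a grid of side n becomes a grid of side 2n. The following pure-Python sketch illustrates that operation on a single-channel grid under this assumption. The function name `double` is hypothetical, and the sketch is not the repository's implementation, which presumably operates on JAX arrays.

```python
from typing import List

Grid = List[List[float]]


def double(grid: Grid) -> Grid:
    """Nearest-neighbour 2x upsampling: each cell becomes a 2x2 block of copies."""
    out: Grid = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # repeat each cell along the width
        out.append(wide)
        out.append(list(wide))  # repeat the widened row along the height (as a separate list)
    return out


# Starting from a single cell, five doublings yield a 32x32 grid.
g: Grid = [[0.5]]
for _ in range(5):
    g = double(g)
print(len(g), len(g[0]))
```

For an (H, W) JAX array the same operation could be written as `jnp.repeat(jnp.repeat(x, 2, axis=0), 2, axis=1)`, or with einops as `einops.repeat(x, 'h w -> (h 2) (w 2)')`. A shape test like the one above (1x1 grown to 32x32) is the kind of check a doubling test would make.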

References

[1] R. B. Palm, M. G. Duque, S. Sudhakaran, and S. Risi. Variational Neural Cellular Automata. ICLR 2022.

[2] H. Larochelle and I. Murray. The Neural Autoregressive Distribution Estimator. AISTATS 2011.