This repository contains code to reproduce the results of "Variational Neural Cellular Automata" [1].
It is built on the autograd engine JAX, the neural network library equinox, the optimization library optax and the tensor operation library einops.
The main results reproduced are those on the binarized MNIST dataset [2].
To install the requirements locally, run:
pip install -r requirements.txt
To train a model on the 8 TPU v3 cores available on Kaggle, import the script main-train.py as a Kaggle notebook via:
- Create -> New Notebook -> File -> Import Notebook
Then select the TPU accelerator.
The script can then be run as a notebook.
To evaluate a trained model, use the script eval.py, loading it onto Kaggle in the same way as the training script.
Our models achieve the following performance on binarized MNIST:
Model | Test-set IWELBO (128 importance-weighted samples) |
---|---|
BaselineVAE | -84.64 nats |
DoublingVNCA | -84.15 nats |
NonDoublingVNCA | -89.3 nats |
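The IWELBO is a lower bound on log p(x) that tightens as the number of importance samples K grows. For K samples z_k ~ q(z|x), it takes the log importance weights log w_k = log p(x|z_k) + log p(z_k) - log q(z_k|x) and computes log((1/K) sum_k w_k) with natural logarithms, hence the unit nats. The following is a minimal, framework-free sketch of that estimator, assuming a Bernoulli decoder over binary pixels; the function names are hypothetical and not taken from eval.py, which presumably computes the same quantity with JAX.

```python
import math
from typing import Sequence


def log_mean_exp(values: Sequence[float]) -> float:
    """Numerically stable log((1/K) * sum_k exp(values[k]))."""
    m = max(values)  # subtract the max so exp() cannot overflow
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


def bernoulli_log_likelihood(x: Sequence[int], logits: Sequence[float]) -> float:
    """log p(x | z) for binary pixels x under Bernoulli(sigmoid(logits)).

    Uses log sigmoid(l) = l - softplus(l) and log(1 - sigmoid(l)) = -softplus(l),
    so each pixel contributes x * l - softplus(l).
    """
    def softplus(l: float) -> float:
        # Stable log(1 + exp(l)) for large positive or negative l
        return max(l, 0.0) + math.log1p(math.exp(-abs(l)))

    return sum(xi * li - softplus(li) for xi, li in zip(x, logits))


def iwelbo(
    log_px_given_z: Sequence[float],
    log_pz: Sequence[float],
    log_qz_given_x: Sequence[float],
) -> float:
    """Importance-weighted ELBO for one data point from K samples z_k ~ q(z|x).

    Each argument is a length-K sequence evaluated at the same samples z_k:
    log w_k = log p(x|z_k) + log p(z_k) - log q(z_k|x), IWELBO = log mean_k w_k.
    """
    log_w = [a + b - c for a, b, c in zip(log_px_given_z, log_pz, log_qz_given_x)]
    return log_mean_exp(log_w)
```

With K = 1 this reduces to a single-sample ELBO estimate; by Jensen's inequality, the log-mean-exp of the log weights is never smaller than their plain mean, which is why more importance samples give a tighter bound.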
The figures from the paper can be reproduced with the following scripts:
Figure | Script to reproduce the figure |
---|---|
Figure 2 | sample.py |
Figure 3 | sample.py |
Figure 4 | latent_interpolate.py and t-sne.py |
Figure 5 | damage_recovery.py |
Figure 6 | linear_probe_figure.py |
Figure 7 | latent_viz.py |
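Figure 4 includes interpolations in latent space. As a rough sketch, assuming latent_interpolate.py interpolates linearly between two encoded latent vectors and then decodes each intermediate point (the actual script may differ, for example by interpolating spherically), the interpolation step could look like this; `lerp_latents` is a hypothetical helper name.

```python
from typing import List, Sequence


def lerp_latents(z0: Sequence[float], z1: Sequence[float], num_steps: int) -> List[List[float]]:
    """Linearly interpolate between two latent vectors, endpoints included."""
    if len(z0) != len(z1):
        raise ValueError("latent vectors must have the same dimension")
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    ts = [i / (num_steps - 1) for i in range(num_steps)]  # 0.0, ..., 1.0
    # Each intermediate latent would then be passed through the decoder
    return [[(1 - t) * a + t * b for a, b in zip(z0, z1)] for t in ts]
```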
To run tests that check the model output shape and the doubling operation, run:
pytest tests.py
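In the doubling VNCA decoder [1], NCA update steps alternate with a doubling step that copies each cell into a 2 x 2 block, so the grid side length doubles each time. The sketch below shows that operation on a plain nested-list grid whose cells can be any value (e.g. a state vector); `double` is a hypothetical helper, not the repository's implementation. For a channels-first JAX array, an equivalent einops call would be `einops.repeat(x, "c h w -> c (h h2) (w w2)", h2=2, w2=2)`, assuming that layout.

```python
from typing import List, TypeVar

T = TypeVar("T")


def double(grid: List[List[T]]) -> List[List[T]]:
    """Copy each cell of an H x W grid into a 2 x 2 block, giving a 2H x 2W grid."""
    out: List[List[T]] = []
    for row in grid:
        doubled_row = [cell for cell in row for _ in range(2)]  # repeat along width
        out.append(doubled_row)
        out.append(list(doubled_row))  # repeat along height (copy to avoid aliasing)
    return out
```

Starting from a single cell, five doublings give a 32 x 32 grid, which is the kind of shape property a test of the doubling operation can check.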
[1] R. B. Palm, M. G. Duque, S. Sudhakaran, and S. Risi. Variational Neural Cellular Automata. ICLR 2022.
[2] H. Larochelle and I. Murray. The Neural Autoregressive Distribution Estimator. AISTATS 2011.