This is a reference implementation for Group Equivariant Subsampling by Jin Xu, Hyunjik Kim, Tom Rainforth and Yee Whye Teh.
See `environment.yml` for the dependencies. For Anaconda users, please create the conda environment with `conda env create -f environment.yml`.
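A minimal setup sketch (the environment name is defined inside `environment.yml`, so the `conda activate` step below uses a placeholder):

```
# Create the conda environment from the provided specification
conda env create -f environment.yml
# Activate it; replace <env-name> with the name declared in environment.yml
conda activate <env-name>
```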
For dSprites and FashionMNIST, data will be automatically downloaded and preprocessed before your first run. For multi-object datasets such as Multi-dSprites, please first run:

python multi_object_datasets/load.py --dataset multi_dsprites --datadir "/tmp"
Train ConvAE, GConvAE-p4, GConvAE-p4m, GAE-p1, GAE-p4 and GAE-p4m on dSprites:

python main.py hydra.job.name=sample_complexity model=conv_ae run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
python main.py hydra.job.name=sample_complexity model=gconv_ae model.n_channels=21 model.fiber_group='rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
python main.py hydra.job.name=sample_complexity model=gconv_ae model.n_channels=15 model.fiber_group='flip_rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
python main.py hydra.job.name=sample_complexity model=eqv_ae model.n_channels=26 model.fiber_group='rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
python main.py hydra.job.name=sample_complexity model=eqv_ae model.n_channels=18 model.fiber_group='flip_rot_2d' model.n_rot=4 run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1
The numbers of channels are rescaled so that the above models have a similar number of parameters. To train on FashionMNIST, simply set `data=fashion_mnist`. To show the progress bar during training, set `run.use_prog_bar=True`.
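For example, a single command combining these overrides to train GAE-p1 on FashionMNIST with the progress bar enabled (an illustrative combination of the options above; the remaining flags follow the dSprites commands):

```
python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=train data=fashion_mnist data.train_set_size=1600 run.use_prog_bar=True run.random_seed=1
```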
To visualise image reconstructions, set `run.mode=reconstruct`. For example, for GAE-p1 on dSprites:

python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=reconstruct data=dsprites data.train_set_size=1600 run.random_seed=1
To evaluate the trained model, set `run.mode=eval` and run:

python main.py hydra.job.name=sample_complexity model=eqv_ae run.mode=eval eval.which_set=test data=dsprites data.train_set_size=1600 run.random_seed=1
To train autoencoders on constrained data for out-of-distribution experiments, one can run (using ConvAE as an example):

python main.py hydra.job.name=ood model=conv_ae run.mode=train data=dsprites data.train_set_size=6400 data.constrained_transform="translation_rotation"
To regenerate the visualisation in the paper, use our Python script at `py_scripts/out_of_distribution.py` (coming soon).
To train the MONet baseline, run:

python main.py hydra.job.name=compare_to_monet model=monet run.mode=train data=multi_dsprites data.train_set_size=6400 data.batch_size=16 run.max_epochs=1000 run.random_seed=1

To train MONet-GAE-p1, run:

python main.py hydra.job.name=compare_to_monet model=eqv_monet run.mode=train data=multi_dsprites data.train_set_size=6400 data.batch_size=16 run.max_epochs=1000 run.random_seed=1
We use Hydra to specify configurations for experiments. As the Hydra documentation puts it, "The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line." Our default Hydra configurations can be found at `conf/`.
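Since the configuration is composed by Hydra, any value shown in the commands above can be overridden from the command line, and standard Hydra features should also apply. As a sketch (the sweep below relies on Hydra's generic `--multirun` support rather than anything documented in this repository):

```
# Sweep over three random seeds using Hydra's multirun mode (hypothetical usage)
python main.py --multirun hydra.job.name=sample_complexity model=eqv_ae run.mode=train data=dsprites data.train_set_size=1600 run.random_seed=1,2,3
```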
The directory `elm/` contains most of our research code. All the data loaders can be found at `elm/data_loader/`, and all the models can be found at `elm/model/`. Experimental results will be generated at `outputs/`, organised by date and job name. By default, logs and checkpoints are directed to `/tmp/log/`, but this can be reconfigured in `conf/config.yaml`.
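To summarise the layout described in this README (directories only; a sketch that omits anything not mentioned above):

```
conf/                    # default Hydra configurations; conf/config.yaml sets the log/checkpoint path
elm/                     # research code
  data_loader/           # data loaders
  model/                 # models
multi_object_datasets/   # load.py for preprocessing multi-object datasets
py_scripts/              # scripts such as out_of_distribution.py
outputs/                 # experimental results, organised by date and job name
```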
To ask questions about the code or report issues, please open an issue directly on GitHub. To discuss research, please email jin.xu@stats.ox.ac.uk.
This repository includes code from two previous projects: GENESIS and multi_object_datasets. Their original licenses have been included.