PyTorch code for the paper "Improved Contrastive Divergence Training of Energy Based Models".
Create a new environment and install the requirements file:
pip install -r requirements.txt
The following command trains a basic CIFAR-10 model.
python train.py --exp=cifar10_model --step_lr=100.0 --num_steps=40 --cuda --ensembles=1 --kl_coeff=1.0 --kl=True --multiscale --self_attn --reservoir
The following command trains a CelebA model.
python train.py --dataset=celeba --exp=celeba_model --step_lr=500.0 --num_steps=40 --kl=True --gpus=8 --filter_dim=128 --multiscale --self_attn --reservoir
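Both training commands pass --step_lr and --num_steps. Assuming these flags set the step size and the number of steps of a Langevin-style sampler that draws negative samples (an assumption based on the flag names, not confirmed by this README), the sampling loop can be sketched as below. The toy quadratic energy, the function names, and the noise_scale parameter are hypothetical and stand in for the repository's neural network and settings.

```python
import random

def toy_energy_grad(x, center=2.0):
    # Gradient of the toy energy E(x) = 0.5 * (x - center)^2.
    # Hypothetical stand-in for the gradient of a learned energy network.
    return x - center

def langevin_sample(x0, step_lr, num_steps, noise_scale=0.0, seed=0):
    # Hypothetical sketch of the update
    #   x <- x - step_lr * dE/dx + noise_scale * N(0, 1),
    # repeated num_steps times.
    rng = random.Random(seed)
    x = x0
    for _ in range(num_steps):
        noise = noise_scale * rng.gauss(0.0, 1.0)
        x = x - step_lr * toy_energy_grad(x) + noise
    return x

# Start far from the low-energy region and run 40 steps, as in --num_steps=40.
sample = langevin_sample(x0=-5.0, step_lr=0.1, num_steps=40, noise_scale=0.05)
```

In this toy setting the sample drifts toward the energy minimum at 2.0. The step size of 0.1 is chosen for the toy energy and is unrelated to the values passed on the command lines above.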
The following command combines models trained on CelebA together.
python celeba_combine.py
Given a list of model names, resume iterations, and conditioned values, the code composes the models together.
The following command combines models trained on CIFAR-10 together.
python cifar10_combine.py
Given a list of model names, resume iterations, and conditioned values, the code composes the models together.
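A common way to compose energy-based models is to sum their energies, which corresponds to multiplying their unnormalized densities. The sketch below illustrates that idea on toy one-dimensional energies. It is an assumption about the general approach, not a copy of what celeba_combine.py or cifar10_combine.py do, and every name in it is hypothetical.

```python
def make_quadratic_energy(center):
    # Toy energy E(x) = 0.5 * (x - center)^2, lowest at x = center.
    # Hypothetical stand-in for one trained, conditioned model.
    return lambda x: 0.5 * (x - center) ** 2

def compose(energies):
    # Summing energies multiplies the unnormalized densities exp(-E_i(x)),
    # so low combined energy requires every model to assign low energy.
    return lambda x: sum(energy(x) for energy in energies)

combined = compose([make_quadratic_energy(0.0), make_quadratic_energy(4.0)])

# Search a coarse grid for the point with the lowest combined energy.
grid = [i / 10 for i in range(-20, 61)]
best = min(grid, key=combined)
```

Here the lowest combined energy lies midway between the two individual minima, which is the kind of compromise a composed model settles on when its components prefer different inputs.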
The ebm_sandbox.py file contains functions for evaluating EBMs (such as out-of-distribution detection). The test_inception.py file contains code to evaluate generations of the model. Additional files such as places_gen.py, and other files of the form *_gen.py, can be used for qualitative generation on different datasets.
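To illustrate what out-of-distribution detection with an EBM can look like, the sketch below scores inputs by their energy and computes an AUROC, treating higher energy as evidence that an input is out of distribution. This is a generic illustration under that assumption, not the procedure in ebm_sandbox.py; the function name and the toy energy values are hypothetical.

```python
def auroc(in_scores, out_scores):
    # Fraction of (out-of-distribution, in-distribution) pairs in which the
    # out-of-distribution input has the higher energy; ties count as half.
    wins = 0.0
    for out_value in out_scores:
        for in_value in in_scores:
            if out_value > in_value:
                wins += 1.0
            elif out_value == in_value:
                wins += 0.5
    return wins / (len(in_scores) * len(out_scores))

# Toy energies, made up for illustration only.
in_energy = [0.1, 0.4, 0.3, 0.8]
out_energy = [0.9, 0.7, 1.2, 0.5]
score = auroc(in_energy, out_energy)
```

A score near 1.0 means the energies separate the two sets well, while a score near 0.5 means they carry little information about which set an input came from.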
If you find the code or paper helpful, please consider citing:
@article{du2020improved,
title={Improved Contrastive Divergence Training of Energy Based Models},
author={Du, Yilun and Li, Shuang and Tenenbaum, Joshua and Mordatch, Igor},
journal={arXiv preprint arXiv:2012.01316},
year={2020}
}