Code for reproducing the results in "Gradient Estimation with Discrete Stein Operators"
https://arxiv.org/abs/2202.09497
NeurIPS 2022 Outstanding Paper Award
Dependencies:
tensorflow >= 2.5.0
tensorflow-datasets >= 4.2.0
tensorflow-probability >= 0.12.2
scipy >= 1.6.3
absl-py >= 0.12.0
pandas >= 1.2.4
numpy >= 1.19.5
tqdm >= 4.60.0
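A quick way to confirm that an environment meets these minimums is to compare the installed versions against the list. The sketch below is illustrative and not part of the repository: the helper names `version_tuple`, `parse_requirement`, and `check_requirements` are hypothetical, the version comparison is deliberately simple (numeric components only), and it assumes the distribution names match the list above, with absl looked up under its PyPI name `absl-py`.

```python
from importlib import metadata

# Minimum versions copied from the dependency list (absl is published as absl-py).
REQUIREMENTS = [
    "tensorflow >= 2.5.0",
    "tensorflow-datasets >= 4.2.0",
    "tensorflow-probability >= 0.12.2",
    "scipy >= 1.6.3",
    "absl-py >= 0.12.0",
    "pandas >= 1.2.4",
    "numpy >= 1.19.5",
    "tqdm >= 4.60.0",
]


def version_tuple(version):
    """Turn '2.5.0' into (2, 5, 0); stops at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad so that '2.5' compares equal to '2.5.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def parse_requirement(line):
    """Split 'name >= x.y.z' into (name, minimum version tuple)."""
    name, minimum = line.split(">=")
    return name.strip(), version_tuple(minimum.strip())


def check_requirements(requirements=REQUIREMENTS):
    """Return (name, problem) pairs; the list is empty when all minimums are met."""
    problems = []
    for line in requirements:
        name, minimum = parse_requirement(line)
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append((name, "not installed"))
            continue
        if version_tuple(installed) < minimum:
            problems.append((name, f"found {installed}"))
    return problems


if __name__ == "__main__":
    for name, problem in check_requirements():
        print(f"{name}: {problem}")
```

Running the script prints one line per package that is missing or older than the listed minimum, and prints nothing when every requirement is satisfied.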
Running VAE experiments:
python experiment_launcher_singlelayer.py --dataset={dataset} --genmo_lr={lr} --infnet_lr={lr} --encoder_type=nonlinear --grad_type={grad_type} --K={K} --D=200 --seed={seed}
- dataset:
  - dynamically binarized: mnist, fashion_mnist, omniglot.
  - non-binarized: continuous_mnist, continuous_fashion, continuous_omniglot.
- lr:
  - dynamically binarized: 1e-3 for mnist and omniglot, 3e-4 for fashion_mnist.
  - non-binarized: 1e-4.
- grad_type:
  - REINFORCE leave-one-out: reinforce_loo
  - DisARM (Dong et al., 2020): disarm
  - Double CV (Titsias & Shi, 2021): double_cv
  - RELAX (Grathwohl et al., 2017): relax (not affected by K, always using 3 evaluations of f)
  - ARMS (Dimitriev & Zhou, 2020): arms
  - RODEO (ours):
    - K=2: discrete_stein_avg
    - K>2: discrete_stein_output_avg
- K: number of samples used, equivalent to the number of evaluations of f in all gradient estimators except RELAX.
- seed: 1-5.
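The placeholder values above can be filled in programmatically. The following is a minimal sketch, not part of the repository: `LEARNING_RATES`, `rodeo_grad_type`, and `build_command` are hypothetical names, and the script only prints the commands instead of launching them. The flags, learning rates, and estimator names are taken from the command and settings listed above.

```python
import shlex

# Learning rates as listed above, keyed by dataset name.
LEARNING_RATES = {
    "mnist": "1e-3",
    "omniglot": "1e-3",
    "fashion_mnist": "3e-4",
    "continuous_mnist": "1e-4",
    "continuous_fashion": "1e-4",
    "continuous_omniglot": "1e-4",
}


def rodeo_grad_type(K):
    """Pick the RODEO estimator name for a given number of samples K."""
    if K < 2:
        raise ValueError("RODEO settings are listed only for K >= 2")
    return "discrete_stein_avg" if K == 2 else "discrete_stein_output_avg"


def build_command(dataset, grad_type, K, seed):
    """Assemble the single-layer launcher invocation as an argument list."""
    lr = LEARNING_RATES[dataset]  # same rate for generative and inference nets
    return [
        "python", "experiment_launcher_singlelayer.py",
        f"--dataset={dataset}",
        f"--genmo_lr={lr}",
        f"--infnet_lr={lr}",
        "--encoder_type=nonlinear",
        f"--grad_type={grad_type}",
        f"--K={K}",
        "--D=200",
        f"--seed={seed}",
    ]


if __name__ == "__main__":
    # Print (rather than run) one command per seed for RODEO on MNIST with K=2.
    for seed in range(1, 6):
        print(shlex.join(build_command("mnist", rodeo_grad_type(2), 2, seed)))
```

Keeping the command as an argument list makes it easy to pass to a job scheduler or to `subprocess.run` once the printed commands look right.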
Running hierarchical VAE experiments:
python experiment_launcher_multilayer.py --dataset={dataset} --genmo_lr={lr} --infnet_lr={lr} --grad_type={grad_type} --K=2 --seed={seed}
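Repeating the hierarchical run over several seeds can be scripted. The sketch below is an assumption-laden illustration rather than repository code: `multilayer_command` and `run_seeds` are hypothetical helpers, the learning rate is passed in by the caller, and the default seed range 1-5 is assumed to carry over from the single-layer settings. With `dry_run=True` (the default) nothing is executed.

```python
import subprocess
import sys


def multilayer_command(dataset, lr, grad_type, seed):
    """Argument list for the hierarchical launcher; K is fixed at 2 as in the command above."""
    return [
        sys.executable, "experiment_launcher_multilayer.py",
        f"--dataset={dataset}",
        f"--genmo_lr={lr}",
        f"--infnet_lr={lr}",
        f"--grad_type={grad_type}",
        "--K=2",
        f"--seed={seed}",
    ]


def run_seeds(dataset, lr, grad_type, seeds=range(1, 6), dry_run=True):
    """Run (or, with dry_run=True, only collect) one launcher call per seed."""
    commands = [multilayer_command(dataset, lr, grad_type, s) for s in seeds]
    if not dry_run:
        for cmd in commands:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return commands


if __name__ == "__main__":
    for cmd in run_seeds("mnist", "1e-3", "disarm"):
        print(" ".join(cmd))
```

Calling `run_seeds(..., dry_run=False)` from the repository root would execute the runs sequentially, one seed at a time.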
To cite this work, please use:
@inproceedings{NEURIPS2022_a5a5b0ff,
author = {Shi, Jiaxin and Zhou, Yuhao and Hwang, Jessica and Titsias, Michalis and Mackey, Lester},
booktitle = {Advances in Neural Information Processing Systems},
editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
pages = {25829--25841},
publisher = {Curran Associates, Inc.},
title = {Gradient Estimation with Discrete Stein Operators},
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/a5a5b0ff87c59172a13342d428b1e033-Paper-Conference.pdf},
volume = {35},
year = {2022}
}
The code is modified from Double CV (https://github.com/thjashin/double-cv), which was originally based on DisARM (https://github.com/google-research/google-research/tree/master/disarm/binary) and ARMS (https://github.com/alekdimi/arms).