Flow MCMC

This is the official repository for the preprint "On Sampling with Approximate Transport Maps" (arXiv), implemented in PyTorch.

Installation

Install the dependencies with

pip install -r requirements.txt

To rerun the Alanine Dipeptide experiments, use conda to install the following packages

conda install -c conda-forge openmm openmmtools mdtraj

and then install the Boltzmann generators package with

pip install git+https://github.com/VincentStimper/boltzmann-generators.git

MCMC Samplers

Flow MCMC provides many popular MCMC samplers under a common API:

  • Metropolis-Adjusted Langevin Algorithm (MALA) with mcmc.mala.MALA
  • Hamiltonian Monte Carlo (HMC) (Neal, 2011) with mcmc.hmc.HMC
  • Random Walk Metropolis-Hastings (RWMH) with mcmc.rwhm.RWHM
  • Elliptical Slice Sampling (ESS) (Murray et al., 2010) with mcmc.ess.ESS
  • Independent Metropolis-Hastings (IMH) with mcmc.imh.IndependentMetropolisHastings
  • Iterated Sampling Importance Resampling (i-SIR) (Andrieu et al., 2010) with mcmc.isir.iSIR
  • Pyro wrappers for HMC and NUTS (Bingham et al., 2019) with mcmc.pyro_mcmc.HMC and mcmc.pyro_mcmc.NUTS

Sampling works by calling

sampler.sample(x_s_t_0, n_steps, target, temp=1.0, warmup_steps=0, verbose=False)

where

  • x_s_t_0 (tensor of shape (batch_size, dim)) is the first sample of each chain
  • n_steps (int) is the length of the chain
  • target (callable) is the log-likelihood of the target distribution (it must support batched inputs)
  • temp (float) is the temperature factor applied to the log-likelihood
  • warmup_steps (int) is the number of burn-in steps (these samples are discarded)
  • verbose (bool) displays a progress bar during sampling
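To make the role of each argument concrete, here is a minimal sketch of a random walk Metropolis sampler with the same call shape. It is not the repository's code: it uses plain Python lists instead of PyTorch tensors so that it runs without dependencies, the names toy_rwmh_sample, step_size and seed are hypothetical, and it assumes that temp simply multiplies the log-density.

```python
import math
import random


def toy_rwmh_sample(x_s_t_0, n_steps, target, temp=1.0, warmup_steps=0,
                    step_size=0.5, seed=0):
    """Hypothetical stand-in for sampler.sample (not the repository's code).

    x_s_t_0: list of batch_size points, each a list of dim floats.
    target:  maps such a batch to a list of batch_size log-densities.
    Returns (chain, acceptance_rate), where chain holds n_steps batches;
    the warmup_steps first states are dropped.
    """
    rng = random.Random(seed)
    x = [list(point) for point in x_s_t_0]
    # Assumption: the temperature factor multiplies the log-density.
    log_p = [temp * value for value in target(x)]
    chain, accepted = [], 0
    for step in range(warmup_steps + n_steps):
        # Gaussian random-walk proposal for every chain at once.
        proposal = [[c + step_size * rng.gauss(0.0, 1.0) for c in point]
                    for point in x]
        log_p_new = [temp * value for value in target(proposal)]
        for i in range(len(x)):
            # Metropolis rule for a symmetric proposal.
            if math.log(1.0 - rng.random()) < log_p_new[i] - log_p[i]:
                x[i], log_p[i] = proposal[i], log_p_new[i]
                if step >= warmup_steps:
                    accepted += 1
        if step >= warmup_steps:
            chain.append([list(point) for point in x])
    return chain, accepted / (n_steps * len(x))


def standard_normal_log_density(batch):
    """Batched target: one unnormalized log-density per point."""
    return [-0.5 * sum(c * c for c in point) for point in batch]


start = [[0.0, 0.0] for _ in range(4)]  # 4 parallel chains in dimension 2
chain, acceptance_rate = toy_rwmh_sample(
    start, n_steps=200, target=standard_normal_log_density, warmup_steps=50
)
```

Here chain contains 200 batches of 4 two-dimensional points, one batch per retained step, and the 50 warm-up states are not returned.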

Note that, unlike many MCMC samplers in PyTorch, all the samplers here support running multiple chains in parallel. Each sampler records diagnostics (acceptance rates, ...), which can be retrieved with sampler.get_diagnostics(diag_name).

You can also use normalizing flows to enhance your sampler in two ways:

  • Using flow-MCMC algorithms (i.e., using the flow as global proposals)
  • Using neutra-MCMC algorithms (i.e., using the flow as a reparametrization map) (Parno & Marzouk, 2018, Hoffman et al., 2019) by wrapping the sampler with mcmc.neutra.NeuTra(inner_sampler, flow)
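As an illustration of the first option (the flow used as a global proposal), one step of i-SIR can be sketched as follows. This is a one-dimensional toy in plain Python, not the code in mcmc.isir: a fixed Gaussian stands in for a trained flow, and the names isir_step, sample_proposal and log_proposal are hypothetical.

```python
import math
import random


def isir_step(x, log_target, sample_proposal, log_proposal, n_particles, rng):
    """One i-SIR step: the current state competes with fresh proposal draws."""
    candidates = [x] + [sample_proposal(rng) for _ in range(n_particles - 1)]
    # Importance weights target / proposal, computed in log space.
    log_w = [log_target(c) - log_proposal(c) for c in candidates]
    max_log_w = max(log_w)
    weights = [math.exp(lw - max_log_w) for lw in log_w]
    # Resample one candidate with probability proportional to its weight.
    return rng.choices(candidates, weights=weights, k=1)[0]


def log_target(x):
    """Unnormalized log-density of N(1, 1)."""
    return -0.5 * (x - 1.0) ** 2


def sample_proposal(rng):
    """Stand-in for sampling from a flow: a draw from N(0, 2^2)."""
    return rng.gauss(0.0, 2.0)


def log_proposal(x):
    """Log-density of N(0, 2^2)."""
    return -0.5 * (x / 2.0) ** 2 - math.log(2.0) - 0.5 * math.log(2.0 * math.pi)


rng = random.Random(0)
state, states = 0.0, []
for _ in range(2000):
    state = isir_step(state, log_target, sample_proposal, log_proposal, 10, rng)
    states.append(state)
isir_mean = sum(states) / len(states)  # should be close to the target mean 1
```

Because the proposals are drawn independently of the current state, the chain can jump between distant regions in one step, provided the proposal (here the Gaussian, in practice the flow) covers the target well.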

Importance sampling is also available as mcmc.classic_is.IS, which has the same API as the MCMC samplers (batch_size and n_steps are ignored) and can be enhanced by a flow.
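For reference, self-normalized importance sampling can be sketched in a few lines. The snippet below is a dependency-free toy, not the implementation in mcmc.classic_is: the function name self_normalized_is is hypothetical, and a Gaussian proposal stands in for a flow.

```python
import math
import random


def self_normalized_is(log_target, sample_proposal, log_proposal, n_samples, rng):
    """Draw from the proposal and return samples with normalized weights."""
    samples = [sample_proposal(rng) for _ in range(n_samples)]
    log_w = [log_target(s) - log_proposal(s) for s in samples]
    # Subtract the maximum before exponentiating for numerical stability.
    max_log_w = max(log_w)
    unnormalized = [math.exp(lw - max_log_w) for lw in log_w]
    total = sum(unnormalized)
    return samples, [u / total for u in unnormalized]


rng = random.Random(0)
samples, weights = self_normalized_is(
    log_target=lambda x: -0.5 * (x - 1.0) ** 2,      # unnormalized N(1, 1)
    sample_proposal=lambda r: r.gauss(0.0, 2.0),     # draws from N(0, 2^2)
    log_proposal=lambda x: -0.5 * (x / 2.0) ** 2,    # N(0, 2^2) up to a constant
    n_samples=5000,
    rng=rng,
)
is_mean = sum(w * s for w, s in zip(weights, samples))  # estimate of E[x]
ess = 1.0 / sum(w * w for w in weights)                 # effective sample size
```

Since the weights are normalized, additive constants in either log-density cancel, which is why the proposal's normalizing constant can be dropped here. The effective sample size gives a rough indication of how well the proposal matches the target.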

We also provide a way to perform adaptive learning of normalizing flows with mcmc.learn_flow.LearnMCMC (Gabrié et al., 2022).

We provide an implementation of RealNVP (Dinh et al., 2016) based on marylou-gabrie/adapt-flow-ergo's implementation as well as a wrapper to flows from VincentStimper/normalizing-flows.

"On Sampling with Approximate Transport Maps"

Here we explain how to rerun the experiments presented in the paper. Note that the output paths are defined at the top of the configuration files in configs/flow_approx/.

Synthetic case studies

All the experiments from the synthetic case studies can be rerun using the following commands

python experiments/flow_approx/gaussians_three_flows.py configs/flow_approx/gaussians_three_flows.yaml --seed {INSERT_SEED}
python experiments/flow_approx/funnel.py configs/flow_approx/funnel.yaml --seed {INSERT_SEED}
python experiments/flow_approx/gaussian_mixture.py configs/flow_approx/gaussians_mixture.yaml --seed {INSERT_SEED}
python experiments/flow_approx/banana.py {OUTPUT_PATH}/backward_dim{DIMENSION}.pkl --loss_type backward_kl --dim {DIMENSION} --seed {SEED}
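To repeat the config-based experiments over several seeds, a small driver script along the following lines may help. It is only a sketch: the seeds 0, 1, 2 are arbitrary placeholders, the helper names are hypothetical, and the banana experiment is left out because it also needs an output path and a dimension.

```python
import shlex
import subprocess

# Command templates copied from the list above; only the seed varies.
TEMPLATES = [
    "python experiments/flow_approx/gaussians_three_flows.py "
    "configs/flow_approx/gaussians_three_flows.yaml --seed {seed}",
    "python experiments/flow_approx/funnel.py "
    "configs/flow_approx/funnel.yaml --seed {seed}",
    "python experiments/flow_approx/gaussian_mixture.py "
    "configs/flow_approx/gaussians_mixture.yaml --seed {seed}",
]


def build_commands(seeds):
    """Return one argument list per (seed, experiment) pair."""
    return [shlex.split(template.format(seed=seed))
            for seed in seeds for template in TEMPLATES]


def run_all(seeds, dry_run=True):
    """Print the commands, or execute them one by one when dry_run is False."""
    for command in build_commands(seeds):
        if dry_run:
            print(" ".join(command))
        else:
            subprocess.run(command, check=True)  # stops at the first failure


run_all(seeds=[0, 1, 2])  # placeholder seeds; dry run only prints the commands
```

Calling run_all with dry_run=False from the repository root would launch the runs sequentially.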

The hyper-parameter grid search can be rerun by using the *_debug.yaml configs. The flows for the mixture of Gaussians can be re-trained using

python experiments/flow_approx/gaussian_mixture_flow.py --dim {DIMENSION} --checkpoint_path {SAVE_PATH}/dim_{DIMENSION}/

Benchmarks on real tasks

Alanine Dipeptide

The flow (also available in experiments/flow_approx/models/aldp/flow_aldp.pt) can be retrained using the procedure described in lollcat/fab-torch (Midgley et al., 2022). Sampling can be performed using

python experiments/flow_approx/aldp.py configs/flow_approx/aldp.yaml --seed {SEED} --save_samples

The data used for the ground truth is available on the authors' Zenodo.

Logistic Regression

The flow for the logistic regression experiment can be obtained by running

python experiments/flow_approx/logistic_regression_flow.py --save_path {OUTPUT_PATH} --neutra_flow

and sampling can be done with

python experiments/flow_approx/logistic_regression.py configs/flow_approx/logistic_regression.yaml --seed {SEED} --neutra_flow

Note that you will need ground truth samples obtained using NUTS by running

python experiments/flow_approx/logistic_regression_gt.py --save_path {OUTPUT_PATH}

Phi Four

The flows for the Phi Four experiment can be obtained by running

python experiments/flow_approx/phi_four_parameters.py configs/flow_approx/phi_four_parameters/global_{DIMENSION}.yaml configs/flow_approx/phi_four_parameters/best_flows_{DIMENSION}.yaml --mala_sampler

and sampling can be done with

python experiments/flow_approx/phi_four.py configs/flow_approx/phi_four.yaml --save_samples --seed {SEED}

Appendix

The flows for Figure 8 can be retrained using

python experiments/flow_approx/many_flows_two_moons.py --load_path {OUTPUT_PATH} --seed {SEED}

🏗️ TODO

  • Fix mcmc.hmc.HMC: right now the warmup phase is broken
  • Allow learning a preconditioning matrix for mcmc.mala.MALA