
Improved sampling via learned diffusions (ICLR2024) and an optimal control perspective on diffusion-based generative modeling (SBM@NeurIPS2022)


Sampling via learned diffusions: sde_sampler

Accompanying code for the paper 'Improved sampling via learned diffusions' [ICLR'24,BibTeX] and 'An optimal control perspective on diffusion-based generative modeling' [SBM@NeurIPS'22,BibTeX].

This repo contains various methods (DIS, Bridge, DDS, PIS) to sample from unnormalized densities by learning to control stochastic differential equations (SDEs). Given an unnormalized target density $\rho=Zp_{\mathrm{target}}$, where $Z = \int \rho(x) \mathrm{d}x$, we optimize a neural network $u$ to control the SDE $$\mathrm{d}X^u_t = (\mu + \sigma u)(X^u_t,t) \mathrm{d}t + \sigma(t) \mathrm{d}W_t, \quad X^u_0 \sim p_{\mathrm{prior}},$$ such that $X^u_T \sim p_{\mathrm{target}}$. Then one can sample from the prior $p_{\mathrm{prior}}$ and simulate the SDE $X^u$ to obtain samples from $p_{\mathrm{target}}$.
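As a rough illustration of the simulation step, the following sketch integrates a scalar version of the SDE above with the Euler-Maruyama scheme. It is not the integrator used in this repo: the function name `simulate_controlled_sde`, the hand-written mean-reverting control (a stand-in for the trained network $u$), and all numerical values are assumptions chosen for illustration.

```python
import math
import random


def simulate_controlled_sde(u, mu, sigma, x0, T=1.0, n_steps=200, rng=random):
    """Euler-Maruyama for the scalar SDE dX = (mu + sigma * u)(X, t) dt + sigma(t) dW."""
    dt = T / n_steps
    x = x0
    for k in range(n_steps):
        t = k * dt
        s = sigma(t)
        drift = mu(x, t) + s * u(x, t)
        # The Brownian increment over a step of length dt is N(0, dt).
        x += drift * dt + s * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


# Hypothetical setup: zero base drift, constant diffusion, and a hand-written
# control pulling samples towards 3.0 (a trained network would replace it).
def mu(x, t):
    return 0.0


def sigma(t):
    return 1.0


def u(x, t):
    return 2.0 * (3.0 - x)


rng = random.Random(0)
# Draw X_0 from a standard normal prior and simulate up to time T.
samples = [
    simulate_controlled_sde(u, mu, sigma, rng.gauss(0.0, 1.0), T=5.0, n_steps=500, rng=rng)
    for _ in range(1000)
]
print(sum(samples) / len(samples))  # sample mean of X_T
```

With this particular control, the terminal samples concentrate around 3.0; a learned control would instead be trained so that $X^u_T$ follows the target distribution.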

Installation

  • Repo: First clone the repo:

    git clone git@github.com:juliusberner/sde_sampler.git
    cd sde_sampler
    
  • Environment: We recommend using Conda to set up the codebase:

    conda create -n sde_sampler python==3.9 pip --yes
    conda activate sde_sampler
    
  • GPU: If you have a GPU, check your CUDA version using nvidia-smi and install compatible cuda (for pykeops) and torch/torchvision packages using the PyTorch install guide (see here for previous versions). For instance, if your CUDA version is >=11.7, you could run:

    conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 cuda-minimal-build=11.7 -c pytorch -c nvidia --yes
    
  • CPU: In case you do not have a GPU, you can replace the previous step with:

    conda install pytorch torchvision cpuonly -c pytorch --yes
    
  • Packages: Now, you can install the sde_sampler package:

    pip install -e .
    
  • Wandb: Finally, login to your wandb account:

    wandb login
    

    You can also omit this step and add the wandb.mode=disabled command line arg to your runs.

  • Additional requirements: See Targets for requirements and steps that are only needed for specific targets.

  • Test: To test the pykeops and torch installations on a machine with GPU:

    python -c "import torch; print(torch.cuda.is_available())"
    python -c "import pykeops; pykeops.test_torch_bindings()"
    

Quick Start

Sample from a shifted double well using the DIS solver and the log-variance divergence:

python scripts/main.py target=dw_shift solver=basic_dis loss.method=lv
  • This outputs a link to the metrics and plots: wandb: 🚀 View run at <link>
  • You also find the metrics/plots/checkpoints locally at logs/<Y-m-d>/<H-M-S>.
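To pick up the local outputs programmatically, a small helper along the following lines may be convenient. It is only a sketch: the function name `latest_run_dir` is hypothetical, and it assumes the `logs/<Y-m-d>/<H-M-S>` layout mentioned above with zero-padded directory names, so that string order matches chronological order.

```python
from pathlib import Path


def latest_run_dir(log_root="logs"):
    """Return the most recent <log_root>/<Y-m-d>/<H-M-S> directory, or None if there is none."""
    root = Path(log_root)
    if not root.is_dir():
        return None
    # Zero-padded date and time names sort chronologically as plain strings.
    runs = sorted(p for p in root.glob("*/*") if p.is_dir())
    return runs[-1] if runs else None


print(latest_run_dir())
```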

How-To

Setup

We use:

Configs

All configs can be found in the folder conf. You can just adapt the existing .yaml configs or add new ones. Configs are hierarchical and will be dynamically created by composition, see the Hydra intro or run

python scripts/main.py --help

Most important:

  1. Configs are based on conf/base.yaml and the solver-specific config in conf/solver.

  2. Config defaults can be overridden on the command line in the same way they are specified in the defaults list of each config. For instance: solver=basic_pis, model@generative_ctrl=clipped, or +lr_scheduler=multi_step.

  3. Each individual entry of the config can be overridden on the command line using the nested config keys separated by dots, e.g., generative_ctrl.base_model.channels=32.

  4. You can also change the wandb setting in conf/base.yaml. To change the project to test, add wandb.project=test to the args.

Combining the examples above:

python scripts/main.py solver=basic_pis model@generative_ctrl=clipped generative_ctrl.base_model.channels=32 +lr_scheduler=multi_step wandb.project=test
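To make the dotted-key convention concrete, the sketch below applies overrides such as generative_ctrl.base_model.channels=32 to a plain nested dictionary. It is not Hydra's parser and does not cover config-group overrides (solver=..., model@generative_ctrl=..., +lr_scheduler=...); the helper `apply_override`, its simplified value parsing, and the initial values in `cfg` are assumptions for illustration.

```python
def apply_override(cfg, override):
    """Set a nested entry of `cfg` from a string such as 'a.b.c=32'."""
    dotted_key, raw = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})
    # Simplified value parsing (an assumption): try int, then float, else keep the string.
    for cast in (int, float):
        try:
            node[leaf] = cast(raw)
            break
        except ValueError:
            continue
    else:
        node[leaf] = raw
    return cfg


# Placeholder config; the initial values are made up for this example.
cfg = {"generative_ctrl": {"base_model": {"channels": 64}}, "wandb": {"project": "default"}}
apply_override(cfg, "generative_ctrl.base_model.channels=32")
apply_override(cfg, "wandb.project=test")
print(cfg)
```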

Multi-run & Slurm

Run the experiment above on the shifted double well as well as a Gaussian mixture model for both the log-variance and the KL divergence (using the hydra multi-run flag -m/--multirun):

python scripts/main.py -m +launcher=<launcher> target=dw_shift,gmm solver=basic_dis loss.method=kl,lv

Set <launcher> to
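Hydra's multi-run mode launches one job per combination of the comma-separated values. The sketch below only enumerates these combinations in plain Python to show which runs the command above produces; it does not call Hydra, and the variable names are illustrative.

```python
from itertools import product

# Sweep values taken from the multi-run command above.
sweep = {
    "target": ["dw_shift", "gmm"],
    "solver": ["basic_dis"],
    "loss.method": ["kl", "lv"],
}

# One list of command line args per point of the Cartesian product.
jobs = [
    [f"{key}={value}" for key, value in zip(sweep, combo)]
    for combo in product(*sweep.values())
]
for args in jobs:
    print("python scripts/main.py", *args)
```

This yields four runs: both targets, each with the KL and the log-variance divergence.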

Wandb sweeps

You can find an example sweep in conf/sweeps.

  1. Invoke the sweep:
wandb sweep conf/sweeps/<sweep_name>.yaml
  2. Start agents as described in the output of the previous command. For slurm, you can use
SWEEP_ID=<wandb_entity>/<wandb_project>/<sweep_id> sbatch -a 0-<num agents> bin/slurm_sweep.sh

Resuming

You can resume a run by specifying its wandb id wandb.id=<wandb-id>. With the default settings, you can run

python scripts/main.py --config-name=setup wandb.id=<wandb-id>

and the previous configuration and latest checkpoint will be automatically downloaded from wandb. When using the same log directory, the wandb id is inferred automatically (which is useful for slurm preemption). You can also add the log directory manually via the command line arg hydra.run.dir=<logs/Y-m-d/H-M-S>. For more flexibility (e.g., adapting configs and command line args), you can also specify the checkpoint ckpt_file=<path-to-ckpt-file> directly.

Experiments

Solvers

Our predefined solvers in conf/solver include the following methods:

  1. Time-Reversed Diffusion Sampler (DIS) solver=dis (see our paper)

  2. Denoising Diffusion Sampler (DDS) solver=dds (see the DDS repo; note that this is not the original implementation since we are using another parametrization and integrator)

  3. Path Integral Sampler (PIS) solver=pis (see the PIS repo)

  4. Bridge sampler (Bridge) solver=bridge (see our paper; this can be viewed as a generalized Schrödinger bridge)

  5. Unadjusted Langevin algorithm (ULA) solver=langevin

The configs with prefix basic_ in conf/solver are simplified and can easily be adapted to specific targets or settings.

For the first four solvers, you can use either the KL divergence loss.method=kl or the log-variance divergence loss.method=lv (see our paper). For the first three solvers, the log-variance divergence can also be computed over trajectories with the same initial point by using loss.method=lv_traj. In most of our experiments, the log-variance divergence led to improved performance.
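The difference between the three loss options can be sketched as follows, assuming the per-trajectory values of the log Radon-Nikodym derivative between the controlled and the target path measures have already been computed. The function names are hypothetical, plain Python lists stand in for tensors, and details of the actual implementation (such as gradient handling or the choice between biased and unbiased variance) are not reproduced here.

```python
from statistics import fmean, pvariance


def kl_loss(log_ratios):
    """KL-type loss: the mean log-ratio over trajectories."""
    return fmean(log_ratios)


def lv_loss(log_ratios):
    """Log-variance loss: the variance of the log-ratio over trajectories."""
    return pvariance(log_ratios)


def lv_traj_loss(grouped_log_ratios):
    """Variance within each group of trajectories sharing an initial point, averaged over groups."""
    return fmean([pvariance(group) for group in grouped_log_ratios])


log_ratios = [0.3, -0.1, 0.8, 0.2]
shifted = [r + 5.0 for r in log_ratios]  # e.g. an unknown additive constant such as log Z
print(kl_loss(log_ratios), kl_loss(shifted))  # the mean moves with the constant
print(lv_loss(log_ratios), lv_loss(shifted))  # the variance does not (up to rounding)
```

The usage lines illustrate one property of the log-variance divergence: an additive constant in the log-ratio, such as the unknown log-normalizing constant, does not change the loss value.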

Targets

Our predefined targets in conf/target include the following distributions:

  • Funnel target=funnel (10d, see our paper and the PIS repo)

  • Gaussian Mixture Model target=gmm (2d, see our paper and the PIS repo)

  • Multi/Double-Well target=dw_shift, target=mw, and target=mw_50d (1d/5d/50d, see our paper)

  • Gaussian target=gauss_shift (1d)

  • Image: target=img (2d, see the SNF repo): For better visualization of the image density, we suggest using eval_batch_size=500000.

  • Rings target=rings (2d, see the PIS repo)

  • Rosenbrock: target=rosenbrock (15d, to test samplers for global optimization, see arxiv:2111.00402 and wikipedia)

  • Alanine Dipeptide target=aladip (60d, see the FAB repo): Install the following additional requirements:

    conda install -c conda-forge openmm openmmtools=0.23.1 --yes
    pip install boltzgen@git+https://github.com/VincentStimper/boltzmann-generators.git@v1.0
    

    Then, download the evaluation data using

    bash bin/download_aladip.sh
    
  • NICE target=nice (196d, see arxiv:2208.07698 and https://github.com/fmu2/NICE): First, train the NICE model on MNIST using

    python scripts/train_nice.py
    

    or, on a slurm cluster, using

    sbatch bin/slurm_train_nice.sh
    

    This saves a checkpoint in data/nice.pt, which is then automatically used.

  • Log Gaussian Cox Process target=cox (1600d, see the PIS repo)
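To make the notion of an unnormalized target concrete, the sketch below builds a log-density from the textbook Rosenbrock function, in the spirit of the Rosenbrock target listed above. It is an illustration only: the function names and the `temperature` parameter are hypothetical, and the scaling used by the target in conf/target may differ.

```python
def rosenbrock(x, a=1.0, b=100.0):
    """Rosenbrock function: sum over i of b * (x[i+1] - x[i]**2)**2 + (a - x[i])**2."""
    return sum(
        b * (x_next - x_i**2) ** 2 + (a - x_i) ** 2
        for x_i, x_next in zip(x, x[1:])
    )


def unnorm_log_prob(x, temperature=1.0):
    """Unnormalized log-density log rho(x) = -f(x) / temperature; Z is never needed."""
    return -rosenbrock(x) / temperature


x_opt = [1.0] * 15  # global minimizer of f for a=1, hence the mode of rho
print(rosenbrock(x_opt), unnorm_log_prob([0.0] * 15))
```

Samples concentrating near the mode of such a density correspond to near-minimizers of the function, which is why this kind of target can be used to test samplers for global optimization.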

References

If you use parts of this codebase in your research, please use the following BibTeX entries.

@inproceedings{berner2022optimal,
  title={An optimal control perspective on diffusion-based generative modeling},
  author={Berner, Julius and Richter, Lorenz and Ullrich, Karen},
  booktitle={NeurIPS 2022 Workshop on Score-Based Methods},
  year={2022}
}

@inproceedings{richter2024improved,
  title={Improved sampling via learned diffusions},
  author={Richter, Lorenz and Berner, Julius},
  booktitle={International Conference on Learning Representations},
  year={2024}
}

License

The majority of the project is licensed under MIT. Portions of the project are adapted from other repositories (as mentioned in the code):