Accompanying code for the papers 'Improved sampling via learned diffusions' [ICLR'24, BibTeX] and 'An optimal control perspective on diffusion-based generative modeling' [SBM@NeurIPS'22, BibTeX].
This repo contains various methods (DIS, Bridge, DDS, PIS) to sample from unnormalized densities by learning to control stochastic differential equations (SDEs). Given an unnormalized target density
- Repo: First, clone the repo:
  git clone git@github.com:juliusberner/sde_sampler.git
  cd sde_sampler
- Environment: We recommend using Conda to set up the codebase:
  conda create -n sde_sampler python==3.9 pip --yes
  conda activate sde_sampler
- GPU: If you have a GPU, check your CUDA version using nvidia-smi and install compatible cuda (for pykeops) and torch/torchvision packages using the PyTorch install guide (see here for previous versions). For instance, if your CUDA version is >=11.7, you could run:
  conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 cuda-minimal-build=11.7 -c pytorch -c nvidia --yes
- CPU: If you do not have a GPU, you can replace the previous step with:
  conda install pytorch torchvision cpuonly -c pytorch --yes
- Packages: Now, you can install the sde_sampler package:
  pip install -e .
- Wandb: Finally, log in to your wandb account:
  wandb login
  You can also omit this step and add the wandb.mode=disabled command line arg to your runs.
- Additional requirements: See targets for requirements and steps that are only needed for specific targets.
- Test: To test the pykeops and torch installations on a machine with a GPU:
  python -c "import torch; print(torch.cuda.is_available())"
  python -c "import pykeops; pykeops.test_torch_bindings()"
Sample from a shifted double well using the DIS solver and the log-variance divergence:
python scripts/main.py target=dw_shift solver=basic_dis loss.method=lv
- This outputs a link to the metrics and plots:
  wandb: 🚀 View run at <link>
- You can also find the metrics/plots/checkpoints locally at logs/<Y-m-d>/<H-M-S>.
We use:
- hydra for config management and experiment execution.
- wandb for experiment tracking and hyperparameter sweeps.
All configs can be found in the folder conf. You can simply adapt the existing .yaml configs or add new ones. Configs are hierarchical and are created dynamically by composition; see the Hydra intro or run
python scripts/main.py --help
Most important:
- Configs are based on conf/base.yaml and the solver-specific configs in conf/solver.
- Config defaults can be overridden on the command line analogously to how they are specified in the defaults list of each config. For instance:
  - conf/base.yaml contains the default solver: dis. To select the PIS solver in conf/solver/basic_pis.yaml, add the command line arg solver=basic_pis.
  - conf/solver/basic_pis.yaml contains the default /model@generative_ctrl: score. To use the model in conf/model/clipped.yaml, add the command line arg model@generative_ctrl=clipped.
  - To add a new default, use +. To add the learning rate scheduler conf/lr_scheduler/multi_step.yaml, use the command line arg +lr_scheduler=multi_step.
- Each individual entry of the config can be overridden on the command line using the nested config keys separated by dots, e.g., generative_ctrl.base_model.channels=32.
- You can also change the wandb settings in conf/base.yaml. To change the project to test, add wandb.project=test to the args.
Combining the examples above:
python scripts/main.py solver=basic_pis model@generative_ctrl=clipped generative_ctrl.base_model.channels=32 +lr_scheduler=multi_step wandb.project=test
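The dotted value overrides in the command above (generative_ctrl.base_model.channels=32, wandb.project=test) address entries of a nested config. The following toy sketch shows that idea on plain Python dicts. It is not how Hydra is implemented: Hydra's override grammar also covers config-group selection (solver=basic_pis), packages (model@generative_ctrl=clipped), appending with +, and deletion with ~, none of which are handled here. The helper names and the base config are hypothetical.

```python
def parse_value(raw):
    """Convert a command line string to int or float when possible, else keep the string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def set_dotted(cfg, dotted_key, value):
    """Set cfg[a][b][c] = value for dotted_key 'a.b.c', creating nested dicts as needed."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


def apply_overrides(cfg, args):
    """Apply 'key.subkey=value' args to a nested dict config, in order."""
    for arg in args:
        dotted_key, raw = arg.split("=", 1)
        set_dotted(cfg, dotted_key, parse_value(raw))
    return cfg


# Hypothetical base config; the real defaults live in conf/base.yaml
cfg = {"wandb": {"project": "my_default_project"}}
apply_overrides(cfg, ["generative_ctrl.base_model.channels=32", "wandb.project=test"])
print(cfg)
```

Unlike this sketch, Hydra by default rejects an override of a key that does not exist in the composed config unless it is prefixed with +, which catches typos in key names.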
Run the experiment above on the shifted double well as well as a Gaussian mixture model, for both the log-variance and the KL divergence (using the Hydra multi-run flag -m/--multirun):
python scripts/main.py -m +launcher=<launcher> target=dw_shift,gmm solver=basic_dis loss.method=kl,lv
Set <launcher> to
- joblib if you work on a local machine. This uses the joblib library.
- slurm if you work on a Slurm cluster. You might need to adapt the default configs in conf/launcher/slurm.yaml to your cluster.
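With -m, every comma-separated override value spans one axis of the sweep, and one run is launched per combination; the command above therefore yields 2 targets × 2 loss methods = 4 runs. The sketch below enumerates these combinations with itertools.product. It is an illustration only: the helper expand_multirun is hypothetical, and the naive split on "," ignores Hydra's richer sweep syntax (e.g., range() or quoted values containing commas).

```python
from itertools import product


def expand_multirun(overrides):
    """Expand comma-separated override values into the Cartesian product of single runs."""
    keys = list(overrides)
    value_lists = [overrides[key].split(",") for key in keys]
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]


jobs = expand_multirun(
    {"target": "dw_shift,gmm", "solver": "basic_dis", "loss.method": "kl,lv"}
)
for job in jobs:
    # Each line corresponds to one single-run command
    print("python scripts/main.py " + " ".join(f"{k}={v}" for k, v in job.items()))
```

This prints four single-run commands, one for each (target, loss.method) pair.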
You can find an exemplary sweep in conf/sweeps.
- Invoke the sweep:
wandb sweep conf/sweeps/<sweep_name>.yaml
- Start agents as described in the output of the previous command. For slurm, you can use
SWEEP_ID=<wandb_entity>/<wandb_project>/<sweep_id> sbatch -a 0-<num agents> bin/slurm_sweep.sh
You can resume a run by specifying its wandb id via wandb.id=<wandb-id>. With the default settings, you can run
python scripts/main.py --config-name=setup wandb.id=<wandb-id>
and the previous configuration and latest checkpoint will be downloaded automatically from wandb.
When using the same log directory, the wandb id is inferred automatically (which is useful for Slurm preemption). You can also set the log directory manually via the command line arg hydra.run.dir=<logs/Y-m-d/H-M-S>.
For more flexibility (e.g., adapting configs and command line args), you can also specify the checkpoint directly via ckpt_file=<path-to-ckpt-file>.
Our predefined solvers in conf/solver include the following methods:
- Time-Reversed Diffusion Sampler (DIS): solver=dis (see our paper)
- Denoising Diffusion Sampler (DDS): solver=dds (see the DDS repo; note that this is not the original implementation since we use a different parametrization and integrator)
- Path Integral Sampler (PIS): solver=pis (see the PIS repo)
- Bridge sampler (Bridge): solver=bridge (see our paper; this can be viewed as a generalized Schrödinger bridge)
- Unadjusted Langevin algorithm (ULA): solver=langevin
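ULA, the only solver in this list that involves no learning, iterates x_{k+1} = x_k + h ∇log ρ(x_k) + sqrt(2h) ξ_k with ξ_k standard Gaussian. Because log ρ enters only through its gradient, the unknown normalizing constant never appears. The NumPy sketch below runs many independent chains in parallel. The 1d double well log ρ(x) = -(x² - 1)² is a hypothetical stand-in, not the repo's dw_shift target, and the step size and number of steps are arbitrary choices.

```python
import numpy as np


def grad_log_density(x):
    # Gradient of the hypothetical double well log rho(x) = -(x**2 - 1)**2
    return -4.0 * x * (x**2 - 1.0)


def ula(grad_log_rho, x0, step_size, n_steps, rng):
    """Unadjusted Langevin algorithm: Euler-Maruyama for the overdamped Langevin SDE."""
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x + step_size * grad_log_rho(x) + np.sqrt(2.0 * step_size) * noise
    return x


rng = np.random.default_rng(0)
# 10_000 independent chains, all started at x = 0 (the barrier between the wells)
samples = ula(grad_log_density, x0=np.zeros(10_000), step_size=1e-2, n_steps=2_000, rng=rng)
```

Since there is no Metropolis correction, the stationary distribution of the chain deviates slightly from ρ for any fixed step size h > 0, which is what "unadjusted" refers to.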
The configs with prefix basic_ in conf/solver are simplified and can easily be adapted to specific targets or settings.
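The learned solvers (DIS, DDS, PIS, Bridge) all sample by simulating an SDE whose drift contains a learned control. As a rough illustration, assuming a generic controlled SDE dX_t = (f(X_t, t) + σ(t) u(X_t, t)) dt + σ(t) dW_t, the sketch below simulates it with the Euler-Maruyama scheme. The function simulate_controlled_sde is hypothetical; the actual drifts, diffusion schedules, parametrizations, and integrators differ per solver (see the papers and the configs in conf/solver), and in practice u is a neural network rather than a fixed function.

```python
import numpy as np


def simulate_controlled_sde(x0, drift, control, sigma, t_end, n_steps, rng):
    """Euler-Maruyama for dX_t = (f(X_t, t) + sigma(t) * u(X_t, t)) dt + sigma(t) dW_t."""
    x = np.array(x0, dtype=float)
    dt = t_end / n_steps
    for k in range(n_steps):
        t = k * dt
        s = sigma(t)
        dw = np.sqrt(dt) * rng.standard_normal(x.shape)  # Brownian increment
        x = x + (drift(x, t) + s * control(x, t)) * dt + s * dw
    return x


# Toy check: f = 0, sigma = 1, and a constant control u = 1 on [0, 1] give X_1 = 1 + W_1,
# i.e., samples with mean 1 and variance 1 when started at X_0 = 0.
rng = np.random.default_rng(0)
samples = simulate_controlled_sde(
    x0=np.zeros(20_000),
    drift=lambda x, t: np.zeros_like(x),
    control=lambda x, t: np.ones_like(x),
    sigma=lambda t: 1.0,
    t_end=1.0,
    n_steps=100,
    rng=rng,
)
```

Learning the control amounts to choosing u such that the distribution of X at the terminal time matches the target; the losses below measure how far the controlled process is from that goal.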
For the first four solvers, you can use either the KL divergence loss.method=kl or the log-variance divergence loss.method=lv (see our paper).
For the first three solvers, the log-variance divergence can also be computed over trajectories with the same initial point by using loss.method=lv_traj.
In most of our experiments, the log-variance divergence led to improved performance.
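Both losses can be written in terms of per-trajectory values of the log Radon-Nikodym derivative log(dP^u/dP_target) between the controlled path measure and the target path measure. With an unnormalized target, these values are only known up to an additive constant involving log Z. The NumPy sketch below shows why this matters: the KL estimate is shifted by that constant (which does not affect gradients), while the variance is invariant to it and vanishes exactly when all trajectories carry the same log-weight. How the log-weights are computed along simulated trajectories, and from which process the trajectories are drawn, is left out; the function names are hypothetical, and the grouping in log_variance_traj_loss is only one plausible reading of loss.method=lv_traj.

```python
import numpy as np


def kl_loss(log_rnd):
    # Monte Carlo estimate of E_{P^u}[log dP^u/dP_target], up to a constant involving log Z
    return float(np.mean(log_rnd))


def log_variance_loss(log_rnd):
    # Sample variance of the log-weights; unaffected by any constant shift such as log Z
    return float(np.var(log_rnd))


def log_variance_traj_loss(log_rnd_grouped):
    # Assumed layout: shape (n_initial_points, n_trajectories_per_point);
    # the variance is taken over trajectories sharing an initial point, then averaged
    return float(np.mean(np.var(log_rnd_grouped, axis=1)))


rng = np.random.default_rng(0)
log_rnd = rng.normal(loc=0.3, scale=0.5, size=1_000)  # made-up log-weights for illustration
print(kl_loss(log_rnd), log_variance_loss(log_rnd))
```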
Our predefined targets in conf/target include the following distributions:
- Gaussian Mixture Model: target=gmm (2d, see our paper and the PIS repo)
- Multi/Double-Well: target=dw_shift, target=mw, and target=mw_50d (1d/5d/50d, see our paper)
- Gaussian: target=gauss_shift (1d)
- Image: target=img (2d, see the SNF repo). For better visualization of the image density, we suggest using eval_batch_size=500000.
- Rings: target=rings (2d, see the PIS repo)
- Rosenbrock: target=rosenbrock (15d, to test samplers for global optimization, see arxiv:2111.00402 and wikipedia)
- Alanine Dipeptide: target=aladip (60d, see the FAB repo). Install the following additional requirements:
  conda install -c conda-forge openmm openmmtools=0.23.1 --yes
  pip install boltzgen@git+https://github.com/VincentStimper/boltzmann-generators.git@v1.0
  Then, download the evaluation data using
  bash bin/download_aladip.sh
- NICE: target=nice (196d, see arxiv:2208.07698 and https://github.com/fmu2/NICE). First, train the NICE model on MNIST using
  python scripts/train_nice.py
  or, on a Slurm cluster, using
  sbatch bin/slurm_train_nice.sh
  This saves a checkpoint in data/nice.pt, which is then used automatically.
- Log Gaussian Cox Process: target=cox (1600d, see the PIS repo)
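Each target is specified through its (possibly unnormalized) log-density. The NumPy sketch below shows two such log-densities of the kind listed above: the standard Rosenbrock function f turned into a density proportional to exp(-f(x)/temperature), and an equal-weight isotropic Gaussian mixture evaluated with a numerically stable log-sum-exp. The temperature, the mixture means, and the standard deviation are hypothetical parameters; the repo's target=rosenbrock and target=gmm configs may use different scalings, weights, and covariances.

```python
import numpy as np


def rosenbrock(x):
    # Standard Rosenbrock function; global minimum f = 0 at x = (1, ..., 1)
    x = np.asarray(x, dtype=float)
    return float(np.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2))


def rosenbrock_log_density(x, temperature=1.0):
    # Hypothetical unnormalized log-density log rho(x) = -f(x) / temperature;
    # its mode is the global minimizer of f, which links sampling to global optimization
    return -rosenbrock(x) / temperature


def gmm_log_density(x, means, std):
    # Normalized log-density of an equal-weight mixture of N(mean_k, std^2 I) components
    x = np.asarray(x, dtype=float)
    means = np.asarray(means, dtype=float)  # shape (n_components, dim)
    dim = means.shape[1]
    sq_dist = np.sum((x - means) ** 2, axis=1)
    log_comp = -0.5 * sq_dist / std**2 - 0.5 * dim * np.log(2.0 * np.pi * std**2)
    m = np.max(log_comp)  # subtract the max before exponentiating for stability
    return float(m + np.log(np.mean(np.exp(log_comp - m))))


print(rosenbrock_log_density(np.ones(15)), gmm_log_density([0.0, 0.0], [[1.0, 0.0], [-1.0, 0.0]], 1.0))
```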
If you use parts of this codebase in your research, please use the following BibTeX entries.
@inproceedings{berner2022optimal,
title={An optimal control perspective on diffusion-based generative modeling},
author={Berner, Julius and Richter, Lorenz and Ullrich, Karen},
booktitle={NeurIPS 2022 Workshop on Score-Based Methods},
year={2022}
}
@inproceedings{richter2024improved,
title={Improved sampling via learned diffusions},
author={Richter, Lorenz and Berner, Julius},
booktitle={International Conference on Learning Representations},
year={2024}
}
The majority of the project is licensed under MIT. Portions of the project are adapted from other repositories (as mentioned in the code):
- https://github.com/fmu2/NICE is also licensed under MIT,
- https://github.com/yang-song/score_sde is licensed under Apache-2.0,
- https://github.com/noegroup/stochastic_normalizing_flows is licensed under BSD-3-Clause,
- the repositories https://github.com/lollcat/fab-torch, https://github.com/qsh-zh/pis, and https://github.com/fwilliams/scalable-pytorch-sinkhorn do not provide licenses.