[NeurIPS 2023] Formulating Discrete Probability Flow Through Optimal Transport

Discrete-Probability-Flow

The source code for our paper "Formulating Discrete Probability Flow Through Optimal Transport", Pengze Zhang*, Hubery Yin*, Chen Li, Xiaohua Xie, NeurIPS 2023. Video: [English]

Abstract

Continuous diffusion models are commonly acknowledged to display a deterministic probability flow, whereas discrete diffusion models do not. In this paper, we aim to establish the fundamental theory for the probability flow of discrete diffusion models. Specifically, we first prove that the continuous probability flow is the Monge optimal transport map under certain conditions, and also present an equivalent evidence for discrete cases. In view of these findings, we are then able to define the discrete probability flow in line with the principles of optimal transport. Finally, drawing upon our newly established definitions, we propose a novel sampling method that surpasses previous discrete diffusion models in its ability to generate more certain outcomes. Extensive experiments on the synthetic toy dataset and the CIFAR-10 dataset have validated the effectiveness of our proposed discrete probability flow.

🔥News🔥

[2024.1.8] Our discrete probability flow now supports controllable generation through interpolation in the latent space.

Discrete Probability Flow on SDDM (Toy dataset)

1) Getting started

  • Python 3.9.0
  • jax 0.4.8
  • jaxlib 0.4.7
  • CUDA 12.1
  • NVIDIA A100 40GB PCIe
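
To confirm that the local Python environment matches the package versions listed above, a short check along these lines may help. This is a minimal sketch: the helper name `check_versions` is hypothetical and not part of this repository, and it only inspects installed Python packages, not CUDA or the GPU.

```python
from importlib import metadata

# Package versions listed above for the SDDM experiments.
EXPECTED = {"jax": "0.4.8", "jaxlib": "0.4.7"}


def check_versions(expected):
    """Return {package: (expected, installed or None)} for each package."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed in this environment
        report[name] = (wanted, installed)
    return report


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions(EXPECTED).items():
        status = "ok" if installed == wanted else "mismatch"
        print(f"{name}: expected {wanted}, found {installed} ({status})")
```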

Open the directory for sddm.

cd sddm

2) Generate your synthetic dataset

The synthetic data comes in 7 categories: 2spirals, 8gaussians, checkerboard, circles, moons, pinwheel, swissroll. Set 'data_name' to one of these names to select a category.

Binary graycode

data_name=XXX bash sddm/synthetic/data/run_binary_data_dump.sh

Base5 code

data_name=XXX bash sddm/synthetic/data/run_base5_data_dump.sh

Base10 code

data_name=XXX bash sddm/synthetic/data/run_base10_data_dump.sh
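
The three options differ in how a quantized coordinate is written as a sequence of discrete symbols. The sketch below illustrates the generic encodings only (binary-reflected Gray code and fixed-length base-k digits). The function names are hypothetical, and the quantization range and code length actually used by the dump scripts are not shown here.

```python
def to_gray_bits(n, num_bits):
    """Binary-reflected Gray code of a non-negative integer, MSB first."""
    gray = n ^ (n >> 1)
    return [(gray >> i) & 1 for i in reversed(range(num_bits))]


def to_base_digits(n, base, num_digits):
    """Fixed-length digits of a non-negative integer in the given base, MSB first."""
    digits = []
    for _ in range(num_digits):
        digits.append(n % base)
        n //= base
    return digits[::-1]


if __name__ == "__main__":
    print(to_gray_bits(5, 4))         # [0, 1, 1, 1]
    print(to_base_digits(37, 5, 3))   # [1, 2, 2]
    print(to_base_digits(37, 10, 3))  # [0, 3, 7]
```

With a Gray code, neighboring integers differ in exactly one bit, so nearby quantized values stay close in Hamming distance.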

Alternatively, download our synthetic dataset directly into ./sddm/data.

3) Train on the synthetic dataset

Binary graycode

data_name=XXX config_name=binary_graycode bash sddm/synthetic/train_binary_graycode.sh

Base5 code

data_name=XXX config_name=base5_code bash ./sddm/synthetic/train_base5_code.sh

Base10 code

data_name=XXX config_name=base10_code bash ./sddm/synthetic/train_base10_code.sh
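
To train on several datasets in sequence, the commands above can be generated in a loop. The sketch below only builds and prints the command strings. The helper name `training_commands` is hypothetical, and it assumes the training script is always named train_<config_name>.sh, which matches the three commands above.

```python
DATA_NAMES = ["2spirals", "8gaussians", "checkerboard", "circles",
              "moons", "pinwheel", "swissroll"]
CONFIG_NAMES = ["binary_graycode", "base5_code", "base10_code"]


def training_commands(data_names=DATA_NAMES, config_names=CONFIG_NAMES):
    """Build one shell command per (encoding, dataset) pair."""
    return [
        f"data_name={data} config_name={config} "
        f"bash sddm/synthetic/train_{config}.sh"
        for config in config_names
        for data in data_names
    ]


if __name__ == "__main__":
    for command in training_commands():
        print(command)
```

The printed lines can be pasted into a shell or saved to a script and run one at a time.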

Alternatively, download our pre-trained model directly into ./sddm/results.

4) Test MMD

Set 'sampler_type' in 'sddm/synthetic/config/*.py' to select lbjf or dpf sampling.

Binary graycode

data_name=XXX config_name=binary_graycode bash sddm/synthetic/binary_test_mmd.sh 

Base5 code

data_name=XXX config_name=base5_code bash sddm/synthetic/base5_test_mmd.sh

Base10 code

data_name=XXX config_name=base10_code bash sddm/synthetic/base10_test_mmd.sh
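
For reference, MMD (maximum mean discrepancy) compares two sample sets through a kernel. The following is a minimal pure-Python sketch of the biased MMD² estimate, assuming an exponential Hamming kernel with a hypothetical `bandwidth` parameter. The kernel, bandwidth, and estimator used by the test scripts above may differ.

```python
import math


def hamming_kernel(x, y, bandwidth=0.1):
    """exp(-normalized Hamming distance / bandwidth) for equal-length sequences."""
    distance = sum(a != b for a, b in zip(x, y)) / len(x)
    return math.exp(-distance / bandwidth)


def mean_kernel(xs, ys, bandwidth):
    """Average kernel value over all pairs (x, y)."""
    total = sum(hamming_kernel(x, y, bandwidth) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_squared(xs, ys, bandwidth=0.1):
    """Biased estimate: E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]."""
    return (mean_kernel(xs, xs, bandwidth)
            + mean_kernel(ys, ys, bandwidth)
            - 2.0 * mean_kernel(xs, ys, bandwidth))
```

A value near zero indicates that the generated samples and the reference samples are hard to tell apart under the chosen kernel.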

5) Test CSD

Set 'sampler_type' in 'sddm/synthetic/config/*.py' to select lbjf or dpf sampling.

Binary graycode

data_name=XXX config_name=binary_graycode bash sddm/synthetic/binary_test_std.sh 

Base5 code

data_name=XXX config_name=base5_code bash sddm/synthetic/base5_test_std.sh

Base10 code

data_name=XXX config_name=base10_code bash sddm/synthetic/base10_test_std.sh
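
The script names suggest that this test reports a standard deviation over generated samples. As a rough illustration only, and assuming the quantity of interest is the spread of repeated samples drawn from the same initial state, the sketch below averages the per-dimension standard deviation across repeated samples. The function name and this exact definition are assumptions; refer to the paper and the test scripts for the actual metric.

```python
import statistics


def mean_per_dimension_std(samples):
    """Average over dimensions of the population std across repeated samples.

    `samples` is a list of equal-length numeric sequences, all assumed to be
    generated from the same initial state.
    """
    dims = list(zip(*samples))  # regroup values by dimension
    return sum(statistics.pstdev(values) for values in dims) / len(dims)
```

Under this reading, a smaller value means that repeated sampling from the same starting point produces more similar outputs.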

Discrete Probability Flow on TauLDR (CIFAR-10)

1) Getting started

  • Python 3.9.7
  • pytorch 1.12.1
  • torchvision 0.13.1
  • CUDA 11.3
  • NVIDIA A100 40GB PCIe

Open the directory for TauLDR.

cd TauLDR

2) Prepare the pre-trained model

The model is provided by TauLDR. Download it directly into ./TauLDR/models/cifar10.

3) Generate samples for evaluation

Please download the x_T into ./TauLDR for reproduction.

Set 'DPF_type' in 'TauLDR/config/eval/cifar10.py' to 0 for TauLDR sampling or 1 for DPF sampling.

python generate_test_certainty_data.py

4) Test CSD

python test_std.py --DPF_type X

5) Test class-std

Please download the pretrained CIFAR-10 classifier and place it in ./TauLDR.

python test_class_std.py --DPF_type X

6) Test class-entropy

python test_entropy.py --DPF_type X
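
As an illustration of what a class-entropy measurement can look like, the sketch below computes the Shannon entropy of the class labels a classifier assigns to a batch of generated images. It assumes the labels have already been predicted. The function name is hypothetical, and the script above may aggregate its results differently.

```python
import math
from collections import Counter


def class_entropy(labels):
    """Shannon entropy (in nats) of the empirical class distribution."""
    counts = Counter(labels)
    total = len(labels)
    return -sum((c / total) * math.log(c / total) for c in counts.values())
```

Under this reading, an entropy of zero means every sample received the same label, and larger values mean the labels are spread over more classes.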

7) Visualization

Please download the x_T into ./TauLDR to reproduce our Figure 12.

Set 'DPF_type' in 'TauLDR/config/eval/cifar10.py' to 0 for TauLDR sampling or 1 for DPF sampling.

python visualization.py

🔥Application: interpolation in the latent space (Celeb)

We trained TauLDR on the Celeb dataset. You can directly download our pre-trained model into ./TauLDR/models/celeb_128.

Set 'DPF_type' in 'TauLDR/config/eval/celeb.py' to 0 for TauLDR sampling or 1 for DPF sampling.

python visualization_celeb.py

Citation

@inproceedings{
zhang2023formulating,
title={Formulating Discrete Probability Flow Through Optimal Transport},
author={Pengze Zhang and Hubery Yin and Chen Li and Xiaohua Xie},
booktitle={Advances in Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=I9GNrInbdf}
}

Acknowledgement

This project builds on SDDM and TauLDR. We thank the authors for their excellent work and for releasing their code.