
Official implementation of Deep Momentum Schrödinger Bridge

Primary language: Python. License: MIT.

(NeurIPS 2023) DMSB: Deep Momentum Multi-Marginal Schrödinger Bridge

Official PyTorch implementation of the paper "Deep Momentum Multi-Marginal Schrödinger Bridge (DMSB)", which introduces a new class of trajectory inference models that extend Schrödinger Bridge (SB) models to momentum dynamics and to the multi-marginal case.
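For intuition, here is a minimal sketch of the kind of momentum (second-order) dynamics involved. It assumes that position evolves as dx = v dt and that noise enters only through the velocity, dv = u dt + σ dW, simulated with Euler-Maruyama. The `control` callable stands in for a learned policy u and is hypothetical; this is an illustration, not the repository's sampler.

```python
import numpy as np

def simulate_momentum_sde(x0, v0, control, sigma=1.0, T=1.0, n_steps=100, rng=None):
    """Euler-Maruyama for dx = v dt, dv = control(x, v, t) dt + sigma dW.

    `control` is a hypothetical placeholder for a learned policy. Noise is
    injected only into the velocity, so position paths stay smooth.
    Returns the position trajectory (n_steps + 1, *x0.shape) and final velocity.
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = T / n_steps
    x = np.array(x0, dtype=float)
    v = np.array(v0, dtype=float)
    xs = [x.copy()]
    for k in range(n_steps):
        t = k * dt
        a = control(x, v, t)                                  # drift evaluated at the current state
        dW = rng.normal(scale=np.sqrt(dt), size=v.shape)      # Brownian increment
        x = x + v * dt                                        # position follows velocity
        v = v + a * dt + sigma * dW                           # velocity receives drift and noise
        xs.append(x.copy())
    return np.stack(xs), v
```

With `sigma=0` and zero control, particles simply coast at constant velocity, which is a quick sanity check of the integrator.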

Connection with Vanilla Schrödinger Bridge

Example GIF

Toy Examples

Task (--problem-name)                               Results
Mixture Gaussians (gmm)                             [figure]
Semicircle (semicircle)                             [figure]
Petal (petal)                                       [figure]
100-dim single-cell RNA sequencing (RNAsc)          [figure]

If you find this library useful, please cite ⬇️
@article{chen2023deep,
  title={Deep Momentum Multi-Marginal Schr{\"o}dinger Bridge},
  author={Chen, Tianrong and Liu, Guan-Horng and Tao, Molei and Theodorou, Evangelos A},
  journal={arXiv preprint arXiv:2303.01751},
  year={2023}
}

Installation

Note: the environment may conflict with some CUDA versions. I am currently fixing this, but it should work with most CUDA versions. This code is developed with Python 3 and PyTorch >= 1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the environment DMSB:

conda env create --file requirements.yaml python=3.8
conda activate DMSB
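Since the PyTorch version and CUDA setup are the usual sources of trouble, a quick check after activating the environment can help. This is a minimal sketch; `torch_version_ok` is a hypothetical helper, not part of the repository, and it only compares the major and minor version numbers against the PyTorch >= 1.7 requirement stated above.

```python
import re

def torch_version_ok(version: str, minimum=(1, 7)) -> bool:
    """Return True if a version string such as '1.8.1+cu111' meets `minimum` (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Tuple comparison handles cases like (1, 10) >= (1, 7) correctly.
    return (major, minor) >= tuple(minimum)
```

Inside the activated environment you could call `torch_version_ok(torch.__version__)` and `torch.cuda.is_available()` to confirm both the version and that PyTorch can see your GPU.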

Download the RNA-seq dataset from this repo and put it under ./data/RNAsc/ProcessedData/.
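Before launching the RNAsc runs, it can save time to confirm the data landed in the right place. This is a minimal sketch; `check_rna_data` is a hypothetical helper, and because the expected file names are not listed here it only checks that the folder exists and is non-empty.

```python
from pathlib import Path

def check_rna_data(root="./data/RNAsc/ProcessedData"):
    """Return the sorted file names under `root`, or raise if the folder is missing or empty."""
    path = Path(root)
    if not path.is_dir():
        raise FileNotFoundError(f"expected the RNA-seq data under {path.resolve()}")
    files = sorted(p.name for p in path.iterdir() if p.is_file())
    if not files:
        raise FileNotFoundError(f"{path} exists but contains no files")
    return files
```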

Reproducing the results in the paper


We provide checkpoints and code for training from scratch on all datasets reported in the paper.

GMM

python main.py --problem-name gmm --dir reproduce/gmm --log-tb --gpu 1

Memo: The results in the paper should be reproduced after around 6 stages of Bregman iteration.
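In the Schrödinger Bridge literature, Bregman iteration generally means alternating KL (Bregman) projections onto the marginal constraints. As a discrete analogue only, not the repository's neural training stage, the classic Sinkhorn algorithm for entropic optimal transport performs exactly this alternation between two marginals `a` and `b`. The function name, `eps`, and iteration count are illustrative assumptions.

```python
import numpy as np

def sinkhorn(a, b, cost, eps=0.1, n_iters=200):
    """Entropic OT via Sinkhorn: alternating KL projections onto the two marginal constraints.

    Returns a coupling P whose column sums equal b exactly (last projection)
    and whose row sums approach a as the iterations converge.
    """
    K = np.exp(-cost / eps)          # Gibbs kernel of the cost matrix
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)              # projection onto {P : P 1 = a}
        v = b / (K.T @ u)            # projection onto {P : P^T 1 = b}
    return u[:, None] * K * v[None, :]
```

Each loop pass plays the role of one "stage" in this toy setting; the multi-marginal case cycles through more constraints, one per observed marginal.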

Petal

python main.py --problem-name petal --dir reproduce/petal --log-tb

Memo: The results in the paper should be reproduced after around 17 stages of Bregman iteration.

RNAsc

python main.py --problem-name RNAsc --dir reproduce/RNA --log-tb  --num-itr 2000
python main.py --problem-name RNAsc --dir reproduce/RNA-loo1 --log-tb  --use-amp --num-itr 2000 --LOO 1
python main.py --problem-name RNAsc --dir reproduce/RNA-loo2 --log-tb  --use-amp --num-itr 2000 --LOO 2
python main.py --problem-name RNAsc --dir reproduce/RNA-loo3 --log-tb  --use-amp --num-itr 2000 --LOO 3
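To queue the four RNAsc runs from a single script, a small launcher can rebuild the commands above in the same flag order. This is a sketch: the helper names are hypothetical, `--LOO` is assumed to select the held-out (leave-one-out) marginal, and `"python"` is used verbatim as in the commands above (swap in `sys.executable` if you need a specific interpreter).

```python
import subprocess

def rnasc_command(run_dir, num_itr=2000, use_amp=False, loo=None):
    """Build one argv list in the same flag order as the RNAsc commands."""
    argv = ["python", "main.py", "--problem-name", "RNAsc",
            "--dir", run_dir, "--log-tb"]
    if use_amp:
        argv.append("--use-amp")
    argv += ["--num-itr", str(num_itr)]
    if loo is not None:
        argv += ["--LOO", str(loo)]
    return argv

def rnasc_commands(num_itr=2000, loo_indices=(1, 2, 3)):
    """Full run first, then one AMP leave-one-out run per held-out index."""
    cmds = [rnasc_command("reproduce/RNA", num_itr)]
    for k in loo_indices:
        cmds.append(rnasc_command(f"reproduce/RNA-loo{k}", num_itr, use_amp=True, loo=k))
    return cmds

def run_all(cmds, dry_run=True):
    """Print each command; run them one after another when dry_run is False."""
    for argv in cmds:
        print(" ".join(argv))
        if not dry_run:
            subprocess.run(argv, check=True)  # stop on the first failing run
```

Calling `run_all(rnasc_commands())` only prints the commands; pass `dry_run=False` to actually execute them sequentially.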

Where can I find the results?

The visualization results are saved in the folder /results.

Numerical values are logged to TensorBoard, and the event files are saved in the folder /runs.

Checkpoints are saved in the folder /checkpoint, and you can reload a checkpoint with:

python main.py --problem-name [problem-name] --dir [your/dir/name/for/current/run] --log-tb  --load [dir/to/checkpoints/]
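If you have several runs and want the newest checkpoint, a helper like the one below can locate it. This is a minimal sketch: `latest_checkpoint` and `reload_command` are hypothetical helpers, the checkpoint file naming is not specified here (so the default pattern matches any file), and whether `--load` expects the file itself or its folder should be checked against main.py.

```python
from pathlib import Path

def latest_checkpoint(folder="checkpoint", pattern="*"):
    """Return the most recently modified file matching `pattern` under `folder` (recursive)."""
    candidates = [p for p in Path(folder).rglob(pattern) if p.is_file()]
    if not candidates:
        raise FileNotFoundError(f"no files matching {pattern!r} under {folder}")
    return max(candidates, key=lambda p: p.stat().st_mtime)

def reload_command(problem_name, run_dir, load_path):
    """Mirror the reload command above as an argv list."""
    return ["python", "main.py", "--problem-name", problem_name,
            "--dir", run_dir, "--log-tb", "--load", str(load_path)]
```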

The numerical results for all metrics will be displayed in the terminal as well.