Code release for "Guidance with Spherical Gaussian Constraint for Conditional Diffusion(DSG)".
The code implementation is based on https://github.com/DPS2022/diffusion-posterior-sampling. This version only includes the Linear Inverse Problems; the code for Non-linear Problems is coming soon.
Install dependencies:
conda create -n DSG python=3.8
conda activate DSG
pip install -r requirements.txt
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
Download checkpoint ffhq_10m.pt
or imagenet256.pt
from https://github.com/DPS2022/diffusion-posterior-sampling and paste it to ./models/
.
You could modify the parameters following the comment in generate.sh
and run it.
bash generate.sh
You could modify the parameters following the comment in run_DSG.sh
and run it using the hyperparameter in Table 3 in the Appendix of paper.
The results are shown in ./total_results_DSG_DDIM"$DDIM"/DSG_interval_${interval}_ guidance_${guidance_scale}/{TASK}/recon
.
bash run_DSG.sh
-
Change the data_root in
generate.sh
andrun_DSG.sh
. -
Change the self.fpaths in
/data/dataloader.py
e.g. if image is
jpg
format, change it to:
self.fpaths = sorted(glob(root + '/*.jpg', recursive=False))
- Change the model config in
run_DSG.sh
.
If you find our work useful in your research, please consider citing
@inproceedings{
yang2023dsg,
title={Guidance with Spherical Gaussian Constraint for Conditional Diffusion},
author={Lingxiao Yang and Shutong Ding and Yifan Cai and Jingyi Yu and Jingya Wang and Ye Shi},
booktitle={International Conference on Machine Learning},
year={2024}
}