yuech5@mail2.sysu.edu.cn, Sun Yat-Sen University
Official PyTorch implementation of GOUB, a diffusion bridge model that applies Doob's h-transform to the generalized Ornstein-Uhlenbeck (GOU) process. The model can address general image restoration tasks without requiring task-specific prior knowledge.
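For intuition, conditioning an OU process on its endpoint gives a bridge whose marginal at time t is Gaussian in closed form. The snippet below is a minimal scalar sketch under stated assumptions, not the repository's implementation. It assumes a constant θ, a stationary variance λ² (so g_t² = 2λ²θ), the clean image x_0 at t = 0 and the degraded image x_T as the fixed endpoint. Writing σ²(a, b) = λ²(1 − e^{−2θ(b−a)}), the mean is w·x_0 + (1 − w)·x_T with w = e^{−θt}·σ²(t, T)/σ²(0, T), and the variance is σ²(0, t)·σ²(t, T)/σ²(0, T). The function names are hypothetical.

```python
import math
import random


def gou_bridge_marginal(x0, xT, t, T=1.0, theta=1.0, lam=1.0):
    """Mean and variance of x_t given x_0 and x_T for an OU bridge (scalar sketch).

    Assumes a constant theta (so the integrated drift over [a, b] is theta * (b - a))
    and a stationary variance lam**2, i.e. g_t**2 = 2 * lam**2 * theta.
    """
    if not 0.0 <= t <= T:
        raise ValueError("t must lie in [0, T]")
    # Unconditioned OU transition variances over [0, t], [t, T] and [0, T].
    var_0t = lam**2 * (1.0 - math.exp(-2.0 * theta * t))
    var_tT = lam**2 * (1.0 - math.exp(-2.0 * theta * (T - t)))
    var_0T = lam**2 * (1.0 - math.exp(-2.0 * theta * T))
    # Weight on the clean endpoint x0; it falls from 1 at t = 0 to 0 at t = T.
    w = math.exp(-theta * t) * var_tT / var_0T
    mean = w * x0 + (1.0 - w) * xT
    # Bridge variance vanishes at both endpoints.
    var = var_0t * var_tT / var_0T
    return mean, var


def sample_xt(x0, xT, t, rng=random, **kwargs):
    """Draw one sample of x_t from the bridge marginal."""
    mean, var = gou_bridge_marginal(x0, xT, t, **kwargs)
    return mean + math.sqrt(var) * rng.gauss(0.0, 1.0)
```

The endpoints act as a quick check: at t = 0 the marginal collapses onto x_0, and at t = T onto x_T, which is the pinning that the h-transform adds to the plain OU process.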
This code is developed with Python 3; we recommend Python >= 3.8 and PyTorch == 1.13.0. Create and activate a conda environment, then install the dependencies with:
conda create --name GOUB python=3.8
conda activate GOUB
pip install -r requirements.txt
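After activating the environment, a quick sanity check can confirm the recommended versions. This is a hypothetical helper, not part of the repository; it only reads installed package metadata, so it runs even where PyTorch is missing.

```python
import sys
from importlib import metadata  # in the standard library from Python 3.8


def check_environment(recommended_torch="1.13.0"):
    """Report whether the interpreter and installed PyTorch match the recommended versions."""
    python_ok = sys.version_info >= (3, 8)
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        torch_version = None  # PyTorch is not installed in this environment
    # Local build tags such as "1.13.0+cu117" are ignored in the comparison.
    torch_ok = torch_version is not None and torch_version.split("+")[0] == recommended_torch
    return {"python_ok": python_ok, "torch_version": torch_version, "torch_ok": torch_ok}


if __name__ == "__main__":
    print(check_environment())
```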
To test a pretrained model:
- Prepare datasets.
- Download pretrained checkpoints here.
- Modify the options in options/test.yml, including dataroot_GT, dataroot_LQ and pretrain_model_G.
- Choose the model to sample with (default: GOUB) in the test function of codes/models/denoising_model.py.
- Run:
python test.py -opt=options/test.yml
The test results will be saved in \results.
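Editing the options can also be scripted once the YAML file is loaded into a dict. The sketch below is hypothetical: the nested layout of `opt` and the placeholder paths are assumptions, so `set_option` searches the whole tree for each key instead of relying on a fixed path, and reports how many entries it changed.

```python
def set_option(opt, key, value):
    """Recursively set every occurrence of `key` in a nested options dict/list.

    Returns the number of entries updated, so a misspelled key (0 updates) is easy to catch.
    """
    count = 0
    if isinstance(opt, dict):
        for k, v in opt.items():
            if k == key:
                opt[k] = value  # replacing a value does not change the dict's size
                count += 1
            else:
                count += set_option(v, key, value)
    elif isinstance(opt, list):
        for item in opt:
            count += set_option(item, key, value)
    return count


# Hypothetical layout for illustration; the real options/test.yml may nest keys differently.
opt = {
    "datasets": {"test1": {"dataroot_GT": None, "dataroot_LQ": None}},
    "path": {"pretrain_model_G": None},
}
updates = {
    "dataroot_GT": "/path/to/GT",  # placeholder paths
    "dataroot_LQ": "/path/to/LQ",
    "pretrain_model_G": "/path/to/checkpoint.pth",
}
for key, value in updates.items():
    if set_option(opt, key, value) == 0:
        raise KeyError(f"{key} not found in options")
```

With PyYAML installed, `opt` would come from `yaml.safe_load` on options/test.yml and be written back with `yaml.safe_dump`; note that such a round trip drops any comments in the file.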
To train a model:
- Prepare datasets.
- Modify the options in options/train.yml, including dataroot_GT and dataroot_LQ.
- For a single GPU, run:
python train.py -opt=options/train.yml
- For multiple GPUs, run the following (attention: see Important Option Details):
python -m torch.distributed.launch --nproc_per_node=2 --master_port=1111 train.py -opt=options/train.yml --launcher pytorch
The training logs will be saved in \experiments.
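The two training commands differ only in the launcher, so they can be generated from the GPU count. This is a hypothetical helper; it assumes one process per GPU (so `--nproc_per_node` equals the number of GPUs) and reuses the port and option file from the commands above as defaults.

```python
import shlex


def launch_command(num_gpus, opt="options/train.yml", master_port=1111):
    """Build the training command as an argument list for subprocess.run."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if num_gpus == 1:
        return ["python", "train.py", f"-opt={opt}"]
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={num_gpus}",  # one process per GPU (assumed)
        f"--master_port={master_port}",
        "train.py", f"-opt={opt}", "--launcher", "pytorch",
    ]


print(shlex.join(launch_command(2)))
```

Passing the list to `subprocess.run(cmd, check=True)` avoids shell quoting issues; `shlex.join` is used here only to print a copy-pasteable form.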
We provide interface.py for deraining, which generates the HQ result from an LQ image alone (no ground truth needed):
- Fill in the LQ path in options/test.yml.
- Run:
python interface.py
- The interface will be served on the local server at 127.0.0.1.
Interfaces for other tasks can be written in the same way.
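To confirm the interface is up before opening a browser, a plain TCP probe is enough. This is a hypothetical helper; the default port 7860 is an assumption (it is Gradio's default, and the README does not state which port interface.py uses), so pass the port that interface.py prints on startup.

```python
import socket


def server_is_up(host="127.0.0.1", port=7860, timeout=1.0):
    """Return True if something accepts TCP connections on host:port.

    The default port is an assumption; use the address printed by interface.py.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # connection refused, unreachable, or timed out
        return False
```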
Important Option Details:
- dataroot_GT: Ground-truth (high-quality) data path.
- dataroot_LQ: Low-quality data path.
- pretrain_model_G: Pretrained model path.
- GT_size, LQ_size: Size of the patches cropped during training.
- niter: Total number of training iterations.
- val_freq: Frequency of validation during training.
- save_checkpoint_freq: Frequency of saving checkpoints during training.
- gpu_ids: In multi-GPU training, GPU ids are separated by commas.
- batch_size: In multi-GPU training, must satisfy batch_size/num_gpu > 1.
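The gpu_ids and batch_size rules can be checked before launching a run. This is a hypothetical validator, not code from the repository; it accepts gpu_ids as a comma-separated string or a list, and its returned per-GPU batch size assumes the total batch is split evenly across GPUs.

```python
def check_multi_gpu_options(gpu_ids, batch_size):
    """Validate gpu_ids / batch_size against the rule batch_size / num_gpu > 1.

    Returns (num_gpu, per_gpu_batch); per_gpu_batch assumes an even split (assumption).
    """
    if isinstance(gpu_ids, str):
        ids = [int(s) for s in gpu_ids.split(",") if s.strip()]
    else:
        ids = [int(i) for i in gpu_ids]
    if not ids:
        raise ValueError("gpu_ids is empty")
    num_gpu = len(ids)
    # The rule only applies to multi-GPU training.
    if num_gpu > 1 and batch_size / num_gpu <= 1:
        raise ValueError(
            f"batch_size={batch_size} is too small for {num_gpu} GPUs: "
            "need batch_size / num_gpu > 1"
        )
    return num_gpu, batch_size // num_gpu
```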
We provide brief guidelines for computing the FID between two sets of images:
- Install the FID library:
pip install pytorch-fid
- Compute FID:
python -m pytorch_fid GT_images_file_path generated_images_file_path --batch-size 1
If all the images are the same size, you can remove --batch-size 1 to speed up computation, since images in one batch must share the same size.
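Whether --batch-size 1 can be dropped depends on every image in a folder sharing one size. A minimal standard-library sketch, assuming the images are PNG files (the only format it reads), checks this by parsing the width and height stored in each file's IHDR header; other formats would need a library such as Pillow. The function names are hypothetical.

```python
import struct
from pathlib import Path

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path):
    """Return (width, height) read from a PNG file's IHDR chunk."""
    with open(path, "rb") as f:
        header = f.read(24)
    # Layout: 8-byte signature, 4-byte chunk length, b"IHDR", then width and height (big-endian).
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a PNG file")
    width, height = struct.unpack(">II", header[16:24])
    return width, height


def all_same_size(folder):
    """True if every *.png in `folder` has the same dimensions (also True for an empty folder)."""
    sizes = {png_size(p) for p in Path(folder).glob("*.png")}
    return len(sizes) <= 1
```

Note that `glob("*.png")` is case-sensitive on most Linux file systems, so files ending in `.PNG` would be skipped.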