
Waving Goodbye to Low-Res: A Diffusion-Wavelet Approach for Image Super-Resolution


This work presents a novel Diffusion-Wavelet (DiWa) approach for Single-Image Super-Resolution (SISR). It leverages the strengths of Denoising Diffusion Probabilistic Models (DDPMs) and Discrete Wavelet Transformation (DWT). By enabling DDPMs to operate in the DWT domain, our DDPM models effectively hallucinate high-frequency information for super-resolved images on the wavelet spectrum, resulting in high-quality and detailed reconstructions in image space.
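As background, operating in the DWT domain means the model predicts four half-resolution wavelet subbands (one coarse approximation plus three high-frequency detail bands) instead of the full-resolution image. A minimal single-level Haar DWT and its inverse in NumPy, for illustration only (not the repo's implementation):

```python
import numpy as np

def haar_dwt2(x):
    """Single-level 2-D Haar DWT: split an (H, W) image into four
    (H/2, W/2) subbands: LL (coarse), LH, HL, HH (details)."""
    a = (x[0::2, :] + x[1::2, :]) / 2.0  # row lowpass
    d = (x[0::2, :] - x[1::2, :]) / 2.0  # row highpass
    ll = (a[:, 0::2] + a[:, 1::2]) / 2.0
    lh = (a[:, 0::2] - a[:, 1::2]) / 2.0
    hl = (d[:, 0::2] + d[:, 1::2]) / 2.0
    hh = (d[:, 0::2] - d[:, 1::2]) / 2.0
    return ll, lh, hl, hh

def haar_idwt2(ll, lh, hl, hh):
    """Inverse of haar_dwt2: perfect reconstruction from the subbands."""
    a = np.empty((ll.shape[0], ll.shape[1] * 2))
    d = np.empty_like(a)
    a[:, 0::2], a[:, 1::2] = ll + lh, ll - lh
    d[:, 0::2], d[:, 1::2] = hl + hh, hl - hh
    x = np.empty((a.shape[0] * 2, a.shape[1]))
    x[0::2, :], x[1::2, :] = a + d, a - d
    return x
```

Because the transform is invertible, subbands hallucinated on the wavelet spectrum map back to a full-resolution image via the inverse DWT.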

Brief

This is the official PyTorch implementation of Waving Goodbye to Low-Res: A Diffusion-Wavelet Approach for Image Super-Resolution (arXiv paper). The repo was cleaned before uploading; please report any bugs. It complements the unofficial implementation of SR3 (GitHub).

Usage

Environment

pip install -r requirement.txt

Continue Training

# Download the pretrained model and set "resume_state" in [sr|sample]_[ddpm|sr3]_[resolution option].json:
"resume_state": [your pretrained model's path]

Data Preparation

If you don't have the data yet, prepare it with the following steps:

Download the dataset and prepare it in LMDB or PNG format using the script below (LMDB is not supported for DIV2K). For DIV2K, remove the `-l` parameter and additionally run the sub-image extraction step described at the end of this section.

# Resize to get 16×16 LR_IMGS and 128×128 HR_IMGS, then prepare 128×128 Fake SR_IMGS by bicubic interpolation
python data/prepare_data.py  --path [dataset root]  --out [output root] --size 16,128 -l
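Conceptually, this resizing step downsamples each HR image to the LR size and then bicubically upsamples the result back to the HR size, producing the "fake SR" conditioning input. A rough sketch with Pillow (the function name `make_lr_fakesr` is hypothetical; the actual logic lives in `data/prepare_data.py`):

```python
from PIL import Image

def make_lr_fakesr(hr_img, lr_size=16, hr_size=128):
    """Sketch of the resizing step: from a high-res image, produce the
    low-res input and its bicubic 'fake SR' counterpart."""
    hr = hr_img.resize((hr_size, hr_size), Image.BICUBIC)
    lr = hr.resize((lr_size, lr_size), Image.BICUBIC)
    # Bicubic upsampling back to HR size gives the conditioning image.
    fake_sr = lr.resize((hr_size, hr_size), Image.BICUBIC)
    return lr, fake_sr
```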

Then point the datasets config to your data path and image resolutions:

"datasets": {
    "train": {
        "dataroot": "dataset/ffhq_16_128", // [output root] in prepare.py script
        "l_resolution": 16, // low resolution need to super_resolution
        "r_resolution": 128, // high resolution
        "datatype": "lmdb", //lmdb or img, path of img files
    },
    "val": {
        "dataroot": "dataset/celebahq_16_128", // [output root] in prepare.py script
    }
},
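Note that the `//` annotations above are explanatory; standard `json.load` rejects them. If you keep such comments in a local config, a small helper can strip them before parsing (a convenience sketch, not part of the repo; it does not handle `//` inside string values):

```python
import json
import re

def load_config(text):
    """Parse a JSON config that may contain '//' line comments and
    trailing commas, as in the annotated snippet above."""
    stripped = re.sub(r'//[^\n]*', '', text)          # drop // comments
    stripped = re.sub(r',\s*([}\]])', r'\1', stripped)  # drop trailing commas
    return json.loads(stripped)
```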

For DIV2K, you will need to extract the sub-images beforehand:

python data/prepare_div2k.py  --path [dataset root]  --out [output root]

Note: LMDB does not work for DIV2K.

For the test datasets, place the files in the dataset folder and run:

python data/prepare_natural_tests.py

Training/Resume Training

# Use sr.py and sample.py to train the super-resolution task and unconditional generation task, respectively.
# Edit json files to adjust network structure and hyperparameters
python sr.py -p train -c config/sr_sr3.json

Configurations for Training

| Task | Config File |
| --- | --- |
| 16×16 → 128×128 on FFHQ-CelebaHQ | config/sr_wave_16_128.json |
| 64×64 → 512×512 on FFHQ-CelebaHQ | config/sr_wave_64_512.json |
| 48×48 → 192×192 on DIV2K | config/sr_wave_48_192.json |
| Ablation: baseline | config/sr_wave_48_192_abl_baseline.json |
| Ablation: Init. Pred. only | config/sr_wave_48_192_abl_pred_only.json |
| Ablation: DWT only | config/sr_wave_48_192_abl_wave_only.json |
| Ablation: DiWa | config/sr_wave_48_192_abl_wave+pred.json |

Test/Evaluation

# Edit the json to add the pretrained model path, then run the evaluation
python sr.py -p val -c config/sr_sr3.json

# Quantitative evaluation using PSNR/SSIM/LPIPS metrics on a given result root
python eval.py -p [result root]
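For reference, PSNR, one of the metrics reported here, is straightforward to compute; a minimal NumPy sketch (illustrative only, not the repo's `eval.py`):

```python
import numpy as np

def psnr(ref, test, max_val=255.0):
    """Peak signal-to-noise ratio between two images in [0, max_val]."""
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```

SSIM and LPIPS are more involved (windowed statistics and a learned network, respectively) and are best taken from established implementations.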

Inference Alone

Set the image path in the config file, then run:

python infer.py -c [config file]