JDRL

The implementation of "Learning Single Image Defocus Deblurring with Misaligned Training Pairs".

Prerequisites

  • The code has been tested with the following environment:
    • Ubuntu 18.04
    • Python 3.7.9
    • PyTorch 1.7.0
    • cudatoolkit 10.0.130
    • NVIDIA TITAN RTX GPU
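The helper below is a convenience sketch, not part of the repository. It prints the local Python, PyTorch and CUDA versions next to the tested versions listed above. It only uses `platform` and, if installed, `torch.__version__`, `torch.version.cuda` and `torch.cuda.get_device_name`. The names `TESTED`, `environment_report` and `print_report` are hypothetical.

```python
import importlib.util
import platform

# Tested environment, copied from the Prerequisites list above.
TESTED = {
    "python": "3.7.9",
    "torch": "1.7.0",
    "cuda": "10.0.130 (cudatoolkit)",
    "gpu": "NVIDIA TITAN RTX",
}


def environment_report():
    """Collect local versions; torch-related fields stay None if PyTorch is absent."""
    report = {"python": platform.python_version(), "torch": None, "cuda": None, "gpu": None}
    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch"] = torch.__version__
        # CUDA version this PyTorch build was compiled against (None for CPU-only builds).
        report["cuda"] = torch.version.cuda
        if torch.cuda.is_available():
            report["gpu"] = torch.cuda.get_device_name(0)
    return report


def print_report():
    """Print each local version next to the version the authors tested with."""
    for key, value in environment_report().items():
        print(f"{key:>6}: {value}  (tested: {TESTED[key]})")
```

Calling `print_report()` only reports the differences. Whether a newer PyTorch or CUDA version still works with the released code is not stated in this README.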

Datasets

Setting A

Setting B

Preparation (for flow estimation)

Download pwc-weight (the PWC-Net weights used for flow estimation) and put it in the './pwc' folder.

Test

$ cd JDRL
  • Test on SDD:
    MPRNet* (MPRNet with JDRL, trained on SDD):
    $ python test.py --test_path './SDD/test/' --checkpoint_path './checkpoint/mprnet-jdrl-sdd.pth' --model 'MPRNet'
    UNet* (UNet with JDRL, trained on SDD):
    $ python test.py --test_path './SDD/test/' --checkpoint_path './checkpoint/unet-jdrl-sdd.pth' --model 'UNet'
  • Test on DPDD:
    MPRNet* (MPRNet with JDRL, trained on DPDD):
    $ python test.py --test_path './DPDD/test/' --checkpoint_path './checkpoint/mprnet-jdrl-dpdd.pth' --model 'MPRNet'
  • Others:
    ifan-jdrl-dpdd: IFAN with JDRL, trained on the DPDD dataset. To test this model, please refer to IFAN.
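To evaluate all three released checkpoints in one go, the commands above can be wrapped in a small script. This is a sketch and not a file in the repository. It reuses only the `test.py` flags shown above. The names `RELEASED`, `build_commands` and `run_all` are hypothetical.

```python
import subprocess
import sys

# (model, test set, checkpoint) triples copied from the test commands above.
RELEASED = [
    ("MPRNet", "./SDD/test/", "./checkpoint/mprnet-jdrl-sdd.pth"),
    ("UNet", "./SDD/test/", "./checkpoint/unet-jdrl-sdd.pth"),
    ("MPRNet", "./DPDD/test/", "./checkpoint/mprnet-jdrl-dpdd.pth"),
]


def build_commands(python=sys.executable):
    """Return one argv list per released checkpoint (shell quotes are not needed here)."""
    return [
        [python, "test.py", "--test_path", test_path,
         "--checkpoint_path", checkpoint, "--model", model]
        for model, test_path, checkpoint in RELEASED
    ]


def run_all(python=sys.executable):
    """Run the evaluations in sequence; call this from inside the JDRL directory."""
    for cmd in build_commands(python):
        print(" ".join(cmd))
        # check=True raises CalledProcessError as soon as one evaluation fails.
        subprocess.run(cmd, check=True)
```

The commands are passed as argument lists rather than shell strings. This way the quoted paths in the README reach `test.py` exactly as the shell would deliver them.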

Train

  • SDD: 512×512 image patches are provided in './SDD/train_patches'.
  • DPDD: crop the images of the DPDD training set into 512×512 patches using the same settings as DPDNet.
    After obtaining the training patches, organize the training set in the directory layout expected by the training code.
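When cropping DPDD, each blurry input and its sharp target must be cut at the same positions so that the patches still correspond. Below is a minimal sketch assuming PIL images, whose `Image.crop` takes a `(left, upper, right, lower)` box. The helpers `patch_boxes` and `crop_pair` are hypothetical, as is the default stride. To reproduce the paper, use DPDNet's actual patch settings.

```python
from typing import List, Tuple

# (left, upper, right, lower): the box convention used by PIL's Image.crop.
Box = Tuple[int, int, int, int]


def patch_boxes(width: int, height: int, size: int = 512, stride: int = 512) -> List[Box]:
    """Grid of size x size boxes that lie fully inside the image, in row-major order.

    Border regions that cannot hold a complete patch are dropped.
    """
    if stride <= 0:
        raise ValueError("stride must be positive")
    if width < size or height < size:
        return []
    boxes = []
    for top in range(0, height - size + 1, stride):
        for left in range(0, width - size + 1, stride):
            boxes.append((left, top, left + size, top + size))
    return boxes


def crop_pair(blurry, sharp, size: int = 512, stride: int = 512):
    """Crop a (blurry, sharp) pair of PIL images with identical boxes."""
    if blurry.size != sharp.size:  # PIL .size is (width, height)
        raise ValueError("input and target must have the same size")
    boxes = patch_boxes(*blurry.size, size=size, stride=stride)
    return [(blurry.crop(box), sharp.crop(box)) for box in boxes]
```

For example, a 1680×1120 image yields 6 non-overlapping 512×512 patches with `stride=512`, or 15 overlapping ones with `stride=256`.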

Start training (on SDD, UNet*)

$ cd JDRL
$ python train.py

To apply JDRL to other models trained on the DPDD dataset: (i) initialize the reblurring module: keep the pretrained deblurring model's weights fixed and train only the reblurring module for several epochs; (ii) decay the learning rate by a factor of 0.1 to 0.01, then jointly train the reblurring module and the deblurring module (your model).
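The two-stage schedule above can be expressed as a per-epoch configuration. This is a sketch of our reading of the README, not the authors' training code. The names `StageConfig`, `jdrl_stage`, `apply_stage` and `init_epochs` are hypothetical. `apply_stage` assumes PyTorch-style objects: `nn.Module.requires_grad_` to freeze or unfreeze, and the optimizer's `param_groups` to set the learning rate.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageConfig:
    train_deblur: bool  # update the deblurring model (your model)?
    train_reblur: bool  # update the reblurring module?
    lr: float           # learning rate for this epoch


def jdrl_stage(epoch: int, init_epochs: int, base_lr: float, decay: float = 0.1) -> StageConfig:
    """Configuration for a 0-based epoch index.

    Stage (i), epoch < init_epochs: deblurring weights frozen, reblurring module
    trained at base_lr. Stage (ii): both trained jointly at base_lr * decay.
    """
    if not 0.01 <= decay <= 0.1:
        raise ValueError("decay should be between 0.01 and 0.1")
    if epoch < init_epochs:
        return StageConfig(train_deblur=False, train_reblur=True, lr=base_lr)
    return StageConfig(train_deblur=True, train_reblur=True, lr=base_lr * decay)


def apply_stage(cfg: StageConfig, deblur_model, reblur_module, optimizer) -> None:
    """Apply cfg to PyTorch-style modules and an optimizer over both modules' parameters."""
    deblur_model.requires_grad_(cfg.train_deblur)
    reblur_module.requires_grad_(cfg.train_reblur)
    for group in optimizer.param_groups:
        group["lr"] = cfg.lr
```

Frozen parameters never receive gradients, and PyTorch optimizers skip parameters whose `.grad` is `None`. A single optimizer over both modules therefore works for both stages. Building a fresh optimizer at the start of stage (ii) is an equally reasonable alternative.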