PnP-Flow

This GitHub repository contains the code for PnP-Flow, a method that combines Plug-and-Play (PnP) methods with pretrained Flow Matching models to solve image restoration problems. Try out the demo!

1. Getting started

1.1. Requirements

  • torch 1.13.1 (or later)
  • torchvision
  • tqdm
  • numpy
  • pandas
  • pyyaml
  • scipy
  • torchdiffeq
  • deepinv

1.2. Download datasets

To download the datasets, we follow the guidelines of https://github.com/clovaai/stargan-v2. Place the downloaded datasets in the data/ folder as follows:

.
├── ...
├── data
│   ├── mnist
│   ├── celeba
└── ...
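
Before launching a run, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_datasets` is hypothetical, and the folder names are simply those shown in the tree.

```python
from pathlib import Path

# Dataset folder names taken from the tree above; extend the tuple as needed.
EXPECTED_DATASETS = ("mnist", "celeba")


def missing_datasets(root="data", expected=EXPECTED_DATASETS):
    """Return the expected dataset folders that are absent under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]


# Run from the repository root.
missing = missing_datasets()
if missing:
    print("Missing dataset folders:", ", ".join(missing))
else:
    print("All expected dataset folders found.")
```

An empty list means every expected folder exists; the check does not inspect the contents of the folders.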

The dataset AFHQ-Cat doesn't have a validation split. To create the same split as we did for our experiments, run scripts/afhq_validation_images.bash.

1.3. Download pre-trained models

We provide the following pre-trained OT Flow Matching models (U-Net):

And the denoisers for the PnP-GS method:

2. Training

You can also use the code to train your own OT Flow Matching model.

You can modify the config options directly in the main_config.yaml file located in config/. Alternatively, config keys can be given as options directly in the command line.
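
To illustrate how flat `key value` pairs on the command line can override values loaded from a YAML file, here is a minimal sketch. It is an assumption about the general mechanism, not the repository's actual parsing code: the helper names `parse_opts` and `apply_overrides` and the default values are hypothetical.

```python
import ast


def parse_opts(opts):
    """Turn a flat ['key', 'value', ...] list into a dict of typed overrides."""
    if len(opts) % 2 != 0:
        raise ValueError("--opts expects an even number of items: key value ...")
    overrides = {}
    for key, raw in zip(opts[0::2], opts[1::2]):
        try:
            # Interprets 'True', '128', '0.1' as bool, int, float.
            overrides[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Anything else (e.g. 'celeba') stays a plain string.
            overrides[key] = raw
    return overrides


def apply_overrides(config, opts):
    """Return a copy of `config` with the command-line overrides applied."""
    merged = dict(config)
    merged.update(parse_opts(opts))
    return merged


# Hypothetical defaults standing in for the contents of main_config.yaml.
defaults = {"dataset": "mnist", "train": False, "eval": True, "batch_size": 64}
opts = "dataset celeba train True eval False batch_size 128 num_epoch 100".split()
config = apply_overrides(defaults, opts)
print(config)
```

Copying the defaults before updating keeps the file-based configuration untouched, so the same defaults can be reused across several runs with different overrides.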

For example, to train the generative flow matching model (here, the U-Net parameterizes the velocity field) on CelebA with a Gaussian latent distribution, run:

python main.py --opts dataset celeba train True eval False batch_size 128 num_epoch 100

Every 5 epochs, the model is saved in ./model/celeba/gaussian/ot. Generated samples are saved in ./results/celeba/gaussian/ot.
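
For orientation, the training objective of a flow matching model in its standard straight-line form can be written as below. This is the textbook formulation, assumed here rather than read from the repository's code, so the exact loss used in main.py may differ. The symbols are illustrative: x_0 is a latent sample, x_1 a data sample, v_θ the U-Net velocity, and π a coupling between latent and data samples (a minibatch optimal transport coupling in the OT variant).

```latex
x_t = (1 - t)\, x_0 + t\, x_1, \qquad t \sim \mathcal{U}[0, 1], \quad (x_0, x_1) \sim \pi,
\qquad
\mathcal{L}(\theta) = \mathbb{E}_{t,\,(x_0, x_1)} \left\| v_\theta(t, x_t) - (x_1 - x_0) \right\|^2 .
```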

Computing generative model scores

After training, the final model is loaded and can be used to generate samples or solve inverse problems. To compute the full FID (based on 50,000 generated samples), the Vendi score, and the sliced Wasserstein score, run:

python main.py --opts dataset mnist train False eval True compute_metrics True solve_inverse_problem False

3. Solving inverse problems

The available inverse problems are:

  • Denoising --> set problem: 'denoising'
  • Gaussian deblurring --> set problem: 'gaussian_deblurring'
  • Super-resolution --> set problem: 'superresolution'
  • Box inpainting --> set problem: 'inpainting'
  • Random inpainting --> set problem: 'random_inpainting'
  • Free-form inpainting --> set problem: 'paintbrush_inpainting'

The parameters of the inverse problems (e.g., noise level) can be adjusted manually in the main.py file.

The available methods are:

  • pnp_flow (our method)
  • ot_ode (from this paper)
  • d_flow (from this paper)
  • flow_priors (from this paper)
  • pnp_diff (from this paper)
  • pnp_gs (from this paper)

3.1. Finding the optimal parameters on the validation set

The optimal parameters can be tuned by running

bash scripts/script_val.sh

You can also use the optimal values we found, as reported in the Appendix of the paper, and input them into the configuration files of the methods.

3.2. Evaluation on the test set

You can either directly run

python main.py --opts dataset celeba train False eval True problem inpainting method pnp_flow

or use the bash file scripts/script_val.sh.

Visual results will be saved in results/celeba/gaussian/inpainting.
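
When sweeping over several problems and methods, it may be convenient to assemble the evaluation command programmatically and catch misspelled names early. The sketch below is not part of the repository: the helper `build_eval_command` is hypothetical, the problem and method names are copied from the lists in Section 3, and the config keys are those used in the command above.

```python
import shlex

# Names copied from the lists in Section 3.
PROBLEMS = {"denoising", "gaussian_deblurring", "superresolution",
            "inpainting", "random_inpainting", "paintbrush_inpainting"}
METHODS = {"pnp_flow", "ot_ode", "d_flow", "flow_priors", "pnp_diff", "pnp_gs"}


def build_eval_command(dataset, problem, method):
    """Return the argv list for an evaluation run, rejecting unknown names."""
    if problem not in PROBLEMS:
        raise ValueError(f"unknown problem: {problem!r}")
    if method not in METHODS:
        raise ValueError(f"unknown method: {method!r}")
    return ["python", "main.py", "--opts",
            "dataset", dataset, "train", "False", "eval", "True",
            "problem", problem, "method", method]


cmd = build_eval_command("celeba", "inpainting", "pnp_flow")
print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root to launch the run.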

Acknowledgements

This repository builds upon the following publicly available codebases: