pip install -r requirements.txt
Given the dataset, please prepare the image paths in a folder named after the dataset, with the following folder structure.
flist/dataset_name
├── train.flist # paths of training images
├── valid.flist # paths of validation images
└── test.flist # paths of testing images
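The file lists can be generated with a short script. The following is a minimal sketch that assumes the usual .flist convention of one image path per line. The helper `write_flists`, the accepted extensions and the 80/10/10 split are illustrative assumptions, not part of this repository.

```python
import random
from pathlib import Path

# Hypothetical helper: the extensions and split ratios are illustrative assumptions.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def write_flists(image_dir, out_dir, ratios=(0.8, 0.1, 0.1), seed=0):
    """Split the images under image_dir into train/valid/test .flist files.

    Each output file lists one absolute image path per line (assumed format).
    Returns a dict mapping split name to the number of paths written.
    """
    paths = sorted(
        str(p.resolve())
        for p in Path(image_dir).rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )
    random.Random(seed).shuffle(paths)  # deterministic shuffle for reproducibility

    n_train = int(len(paths) * ratios[0])
    n_valid = int(len(paths) * ratios[1])
    splits = {
        "train": paths[:n_train],
        "valid": paths[n_train:n_train + n_valid],
        "test": paths[n_train + n_valid:],  # the remainder goes to test
    }

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, items in splits.items():
        text = "\n".join(items) + ("\n" if items else "")
        (out / f"{name}.flist").write_text(text)
    return {name: len(items) for name, items in splits.items()}
```

For example, `write_flists("/data/celebahq", "flist/celebahq")` would produce flist/celebahq/train.flist, valid.flist and test.flist. If your dataset ships with an official split, list those files directly instead of shuffling.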
In this work, we use CelebA-HQ (download available here), Places2 (download available here), and Paris StreetView (requires the authors' permission to download).
ImageNet K-means Cluster: The kmeans_centers.npy file is downloaded from image-gpt and is used to quantize the low-resolution images.
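Quantizing with cluster centers means replacing every pixel by the index of its nearest center, so a low-resolution image becomes a grid of discrete tokens. The sketch below illustrates that idea in pure Python. It assumes, following image-gpt's reference code, that RGB values in [0, 255] are first scaled to [-1, 1] and compared by squared Euclidean distance; the function names and toy centers are hypothetical, and BAT's own data pipeline may differ in detail.

```python
# Minimal sketch of nearest-center color quantization (pure Python, no NumPy).
# Assumption: pixels are RGB in [0, 255] and are scaled to [-1, 1] before
# matching, as in image-gpt's reference code; the centers here are toy values.

def normalize(pixel):
    """Map an RGB pixel from [0, 255] to [-1, 1]."""
    return tuple(c / 127.5 - 1.0 for c in pixel)

def quantize(pixels, centers):
    """Return, for each pixel, the index of the nearest center (squared L2)."""
    tokens = []
    for px in pixels:
        x = normalize(px)
        dists = [sum((a - b) ** 2 for a, b in zip(x, c)) for c in centers]
        tokens.append(min(range(len(centers)), key=dists.__getitem__))
    return tokens

def dequantize(tokens, centers):
    """Map token indices back to RGB values in [0, 255]."""
    return [tuple(round((c + 1.0) * 127.5) for c in centers[t]) for t in tokens]
```

With the real file, the centers would come from something like `[tuple(c) for c in numpy.load("kmeans_centers.npy")]`, assuming it stores an array of shape (num_clusters, 3). For full images a vectorized NumPy or PyTorch version is far faster; the loop form is only meant to make the mapping explicit.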
- Download pre-trained models:
- Put the pre-trained model under the checkpoints folder, e.g.
checkpoints
└── celebahq_bat_pretrain
    └── latest_net_G.pth
- Prepare the input images and masks for testing.
python bat_sample.py --num_sample [1] --tran_model [bat name] --up_model [upsampler name] --input_dir [dir of input] --mask_dir [dir of mask] --save_dir [dir to save results]
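When sampling from a script (for example over several mask folders), the command above can be assembled as an argument list. This is a sketch: the flag names are taken from the command above, while the helper `build_sample_cmd` and every value passed to it (including the upsampler folder name `celebahq_up_pretrain`) are hypothetical placeholders. The model names are presumably the folder names under checkpoints.

```python
import shlex

# Sketch: assemble the bat_sample.py call shown above from Python.
# Flag names come from the README; all values passed in are placeholders.
def build_sample_cmd(tran_model, up_model, input_dir, mask_dir, save_dir, num_sample=1):
    """Return the argv list for one bat_sample.py run."""
    return [
        "python", "bat_sample.py",
        "--num_sample", str(num_sample),
        "--tran_model", tran_model,
        "--up_model", up_model,
        "--input_dir", input_dir,
        "--mask_dir", mask_dir,
        "--save_dir", save_dir,
    ]

cmd = build_sample_cmd("celebahq_bat_pretrain", "celebahq_up_pretrain",
                       "examples/images", "examples/masks", "results")
print(shlex.join(cmd))  # shell-quoted form, useful for logging
```

Passing the list to `subprocess.run(cmd, check=True)` avoids shell-quoting problems with paths that contain spaces.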
Pretrained VGG model: download it from here and move it to models/. This model is used to calculate the training loss for the upsampler.
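The README does not spell out how the VGG model enters the upsampler loss. Because this code builds on SPADE, one plausible reading, which is an assumption rather than something stated here, is SPADE's VGG perceptual loss: an L1 distance between VGG-19 features of the upsampled output $\hat{y}$ and the ground truth $y$ at several layers, weighted toward deeper layers:

$$
\mathcal{L}_{\mathrm{VGG}}(\hat{y}, y) = \sum_{i=1}^{5} w_i \, \frac{1}{N_i} \left\lVert \phi_i(\hat{y}) - \phi_i(y) \right\rVert_1,
\qquad (w_1, \dots, w_5) = \left(\tfrac{1}{32}, \tfrac{1}{16}, \tfrac{1}{8}, \tfrac{1}{4}, 1\right)
$$

where $\phi_i$ is the activation of the $i$-th selected VGG-19 layer (relu1_1 through relu5_1 in SPADE) and $N_i$ is its number of elements. Check the loss definitions in this repository before relying on these exact layers and weights.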
New models can be trained with the following commands.
- Prepare the dataset. Use the --dataroot option to locate the directory of file lists, e.g. ./flist, and specify the dataset name to train with using the --dataset_name option. Specify the mask type and mask ratio using the --mask_type and --pconv_level options, respectively.
- Train the transformer.
# Specify your own dataset or settings in the bash file.
bash train_bat.sh
Please note that some of the transformer settings are defined in train_bat.py instead of options/, and this script uses all available GPUs for training. Please select the GPUs via CUDA_VISIBLE_DEVICES instead of --gpu_ids, which is only used for the upsampler.
- Train the upsampler.
# Specify your own dataset or settings in the bash file.
bash train_up.sh
The upsampler is typically trained on the low-resolution ground truth. We find that using some samples from the trained BAT may help improve performance (e.g. PSNR and SSIM), but since the sampling process is quite time-consuming, training with the ground truth alone can also yield reasonable results.
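Since train_bat.py takes every visible GPU, the device set for the transformer is controlled entirely by CUDA_VISIBLE_DEVICES. A small Python wrapper that restricts the GPUs before calling train_bat.sh could look like the sketch below; the helper names and GPU ids are illustrative assumptions, and only train_bat.sh itself comes from this repository.

```python
import os
import subprocess

# Sketch: restrict the GPUs that train_bat.sh can see via CUDA_VISIBLE_DEVICES.
# The GPU ids are placeholders; train_bat.sh itself is the repository's script.

def training_env(gpus, base_env=None):
    """Return a copy of the environment with CUDA_VISIBLE_DEVICES set to `gpus`."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpus)
    return env

def launch(script="train_bat.sh", gpus=(0, 1), dry_run=True):
    """Run `bash script` on the chosen GPUs; with dry_run, only print the command."""
    env = training_env(gpus)
    cmd = ["bash", script]
    if dry_run:
        print(f"CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']} {' '.join(cmd)}")
        return None
    # check=True raises CalledProcessError if training exits with a non-zero status
    return subprocess.run(cmd, env=env, check=True)
```

The shell equivalent is `CUDA_VISIBLE_DEVICES=0,1 bash train_bat.sh`. For the upsampler, GPUs are selected with --gpu_ids instead, so this wrapper is mainly useful for the transformer.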
If you find this code helpful for your research, please cite our paper.
@inproceedings{yu2021diverse,
title={Diverse Image Inpainting with Bidirectional and Autoregressive Transformers},
author={Yu, Yingchen and Zhan, Fangneng and Wu, Rongliang and Pan, Jianxiong and Cui, Kaiwen and Lu, Shijian and Ma, Feiying and Xie, Xuansong and Miao, Chunyan},
booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
year={2021}
}
This code borrows heavily from SPADE and minGPT; we thank the authors for sharing their code.