/MST_inpainting

Learning a Sketch Tensor Space for Image Inpainting of Man-made Scenes (ICCV 2021)


Chenjie Cao, Yanwei Fu

arXiv | Project Page

Overview

We learn an encoder-decoder model that encodes a Sketch Tensor (ST) space consisting of refined lines and edges. The model then recovers the masked images from the ST space.

News

  • Inference code released.
  • Training code released.

This work has since been further improved in ZITS (CVPR 2022).

Preparation

  1. Prepare the environment.
  2. Download the pretrained masked wireframe detection model LSM-HAWP (retrained from HAWP, CVPR 2020).
  3. Download the weights you need into the 'check_points' folder: P2M (Man-made Places2), P2C (Comprehensive Places2), or shanghaitech (Shanghaitech with all man-made scenes).
  4. For training, we provide irregular and segmentation masks (download) with different masking rates. Define the mask file list before training (see flist_example.txt).
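
As a minimal sketch of step 4, the snippet below collects mask images from a folder and writes them to a file list. It assumes the list is plain text with one path per line; check flist_example.txt for the exact format the training code expects. The helper name `write_flist` and the extension set are hypothetical and not part of this repository.

```python
from pathlib import Path

# Assumption: the file list is plain text with one mask path per line.
# Compare with flist_example.txt before using this for training.
MASK_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def write_flist(mask_dir, flist_path):
    """Write every mask image path found under mask_dir to flist_path."""
    paths = sorted(
        str(p)
        for p in Path(mask_dir).rglob("*")
        if p.is_file() and p.suffix.lower() in MASK_EXTENSIONS
    )
    # One path per line, with a trailing newline when the list is not empty.
    Path(flist_path).write_text("\n".join(paths) + ("\n" if paths else ""))
    return paths
```

The resulting file can then be referenced from training_configs/config_MST.yml.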

Training

Because the training code was rewritten, it differs from the test code in a few ways.

  1. Training uses src/models.py while testing uses src/model_inference.py.

  2. Images are scaled to the range -1 to 1 for training and 0 to 1 for testing.

  3. Masks are always concatenated to the inputs.
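
The two value ranges and the mask concatenation can be illustrated with a short NumPy sketch. This is not the repository's preprocessing code: the scaling formulas, the height x width x channel layout, the 127 threshold, and all function names are assumptions made for illustration. The source only states the ranges and that masks are concatenated to the inputs.

```python
import numpy as np


def to_training_range(img_uint8):
    """Scale uint8 pixels in [0, 255] to floats in [-1, 1] (assumed formula)."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0


def to_testing_range(img_uint8):
    """Scale uint8 pixels in [0, 255] to floats in [0, 1] (assumed formula)."""
    return img_uint8.astype(np.float32) / 255.0


def concat_mask(img, mask_uint8):
    """Append a binary mask (1 = masked) as an extra last channel.

    img is H x W x C and mask_uint8 is H x W with 0 = valid, 255 = masked.
    """
    mask = (mask_uint8 > 127).astype(np.float32)[..., None]
    return np.concatenate([img, mask], axis=-1)
```

An image prepared for one range should not be fed to a model expecting the other, so when moving weights or data between the training and test code it is worth checking which range each path uses.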

  1. Generate wireframes with LSM-HAWP.
CUDA_VISIBLE_DEVICES=0 python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path>
  2. Set the file lists in training_configs/config_MST.yml (example: flist_example.txt).

  3. Train the inpainting model with stage1 and stage2.

python train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0
python train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0

For DDP training on multiple GPUs:

python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
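
In the commands above, --nproc_per_node equals the number of GPU ids passed to --gpu. The sketch below builds the same command for an arbitrary GPU list under the assumption that this one-process-per-GPU relation should always hold. The helper `build_ddp_command` is hypothetical and only assembles the argument list; it does not launch anything.

```python
def build_ddp_command(script, model_name, gpu_ids,
                      config="training_configs/config_MST.yml"):
    """Assemble the DDP launch command for the given GPU ids.

    Assumption: one process is started per listed GPU, as in the
    4-GPU example in this README.
    """
    gpu_ids = list(gpu_ids)
    if not gpu_ids:
        raise ValueError("at least one GPU id is required")
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={len(gpu_ids)}",
        script,
        "--path", model_name,
        "--config", config,
        "--gpu", ",".join(str(g) for g in gpu_ids),
    ]


if __name__ == "__main__":
    # Print the stage1 command for two GPUs.
    print(" ".join(build_ddp_command("train_MST_stage1.py", "<model_name>", [0, 1])))
```

The returned list can be passed to `subprocess.run` or joined into a shell command.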

Test for a single image

python test_single.py --gpu_id 0 \
                      --PATH ./check_points/MST_P2C \
                      --image_path <your image path> \
                      --mask_path <your mask path (0 means valid and 255 means masked)>
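
The mask convention above (0 means valid, 255 means masked) can be sketched with NumPy as follows. The helper names are hypothetical, and the strict two-value check is an assumption; the test script may accept other inputs.

```python
import numpy as np


def rectangle_mask(height, width, top, left, box_h, box_w):
    """Return an H x W uint8 mask: 0 = valid pixel, 255 = pixel to inpaint."""
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[top:top + box_h, left:left + box_w] = 255
    return mask


def is_valid_mask(mask):
    """Check that the mask is uint8 and contains only the values 0 and 255."""
    return mask.dtype == np.uint8 and set(np.unique(mask).tolist()) <= {0, 255}
```

A mask built this way can be saved as a single-channel image with any image library and passed to --mask_path. It must have the same height and width as the image it masks.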

Object Removal Examples

Object removal video

Comparisons

ST Places2