
Reimplementation of the paper "Learning to Sketch with Shortcut Cycle Consistency" (CVPR 2018)


sketch-photo2seq

This is a reimplementation of the CVPR 2018 paper "Learning to Sketch with Shortcut Cycle Consistency".

[Figure: input photos (example1, example2) and the corresponding generated sketches (example1-sketch, example2-sketch)]

Requirements

  • Python 3

  • TensorFlow (>= 1.4.0)

  • Inkscape or CairoSVG (for vector sketch rendering; either one is sufficient)

    sudo apt-get install inkscape
    # or
    pip3 install cairosvg
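Rendering a vector sketch amounts to turning pen offsets into an SVG path that Inkscape or CairoSVG can rasterize. The sketch below assumes the stroke-3 format used by the QuickDraw `.npz` data (each row is `(dx, dy, pen_lift)`, where `pen_lift = 1` means the pen is lifted after that point); `strokes_to_svg` is a hypothetical helper, not a function of this repository.

```python
from typing import Sequence, Tuple


def strokes_to_svg(strokes: Sequence[Tuple[float, float, int]],
                   size: int = 256, margin: float = 10.0) -> str:
    """Convert stroke-3 offsets (dx, dy, pen_lift) into an SVG document string.

    Hypothetical helper: assumes stroke-3 input as in the QuickDraw .npz files.
    """
    header = f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}">'
    # Accumulate relative offsets into absolute coordinates.
    x = y = 0.0
    points = []
    for dx, dy, lift in strokes:
        x += dx
        y += dy
        points.append((x, y, lift))
    if not points:
        return header + "</svg>"

    # Fit the drawing into the canvas, preserving the aspect ratio.
    min_x = min(p[0] for p in points)
    min_y = min(p[1] for p in points)
    extent = max(max(p[0] for p in points) - min_x,
                 max(p[1] for p in points) - min_y) or 1.0
    scale = (size - 2 * margin) / extent

    commands = []
    pen_down = False  # False: the next point starts a new stroke (SVG "M")
    for px, py, lift in points:
        sx = (px - min_x) * scale + margin
        sy = (py - min_y) * scale + margin
        commands.append(f"{'L' if pen_down else 'M'}{sx:.1f} {sy:.1f}")
        pen_down = lift == 0  # a lifted pen ends the current stroke
    path = " ".join(commands)
    return (header + f'<path d="{path}" fill="none" stroke="black" '
            'stroke-width="2"/></svg>')
```

With CairoSVG installed, the resulting string can be rasterized with `cairosvg.svg2png(bytestring=svg.encode("utf-8"), write_to="sketch.png")`; with Inkscape, save it to an `.svg` file and export from the command line instead.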
    

Data Preparations

Following the paper, the model is first pre-trained on the QuickDraw dataset, so both the QuickDraw-shoes and QMUL-shoes data need to be preprocessed as follows:

  1. QuickDraw-shoes

    • Download the sketchrnn_shoes.npz file from QuickDraw
    • Place the file under the datasets/QuickDraw/shoes/npz/ directory
    • Run the command:
      python quickdraw_data_processing.py
      
  2. QMUL-shoes

    • Download the photo data from QMUL-Shoe-Chair-V2
    • Unzip the ShoeV2_photo package and place all .png files under the datasets/QMUL/shoes/photos/ directory
    • Download the preprocessed sketch data from here and place the two .h5 files under the datasets/QMUL/shoes/ directory
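Before training, a quick check of the layout described above can save a failed run. This is a minimal sketch, not part of the repository: `check_data_layout` and `count_sketches` are hypothetical names, the check only looks for the paths listed in the steps above (it does not validate file contents or the exact `.h5` file names), and it assumes the sketch-rnn `.npz` file stores one object array per split (`train`, `valid`, `test`) that needs `allow_pickle=True` to load.

```python
from pathlib import Path
from typing import Dict, List

import numpy as np


def check_data_layout(root: str = "datasets") -> List[str]:
    """Return problems with the expected dataset layout (empty list if it looks OK).

    Hypothetical helper: checks only the paths named in the README steps.
    """
    base = Path(root)
    problems = []
    npz = base / "QuickDraw" / "shoes" / "npz" / "sketchrnn_shoes.npz"
    if not npz.is_file():
        problems.append(f"missing {npz}")
    photos = base / "QMUL" / "shoes" / "photos"
    if not (photos.is_dir() and any(photos.glob("*.png"))):
        problems.append(f"no .png photos under {photos}")
    qmul = base / "QMUL" / "shoes"
    h5_files = sorted(qmul.glob("*.h5")) if qmul.is_dir() else []
    if len(h5_files) != 2:
        problems.append(f"expected 2 .h5 files under {qmul}, found {len(h5_files)}")
    return problems


def count_sketches(npz_path: str) -> Dict[str, int]:
    """Count sketches per split in a sketch-rnn style .npz (assumed: one object array per split)."""
    # Splits hold variable-length stroke sequences, stored as pickled object arrays.
    with np.load(npz_path, encoding="latin1", allow_pickle=True) as data:
        return {split: len(data[split]) for split in data.files}
```

For example, `print(check_data_layout())` from the repository root lists anything still missing, and `count_sketches("datasets/QuickDraw/shoes/npz/sketchrnn_shoes.npz")` reports how many sketches each split contains.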

Training

  1. QuickDraw-shoes pre-training

    • In model.py, set data_type in get_default_hparams to QuickDraw
    • Run the command:
      python sketch_p2s_train.py
      
  2. QMUL-shoes training

    • In model.py, set data_type in get_default_hparams to QMUL
    • Make sure the QuickDraw-shoes pre-trained model files and checkpoint are placed under the outputs/snapshot/ directory
    • In sketch_p2s_train.py, set resume_training to True
    • Run the command:
      python sketch_p2s_train.py
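Before resuming, it can help to confirm that the QuickDraw-shoes checkpoint is actually where the script will look. This is a minimal sketch, assuming the TensorFlow 1.x Saver's plain-text `checkpoint` state file and the default V2 checkpoint format (a `<prefix>.index` file next to the data files); `latest_checkpoint_name` and `pretrained_checkpoint_ready` are hypothetical helpers, not functions of this repository. Inside TensorFlow, `tf.train.latest_checkpoint(checkpoint_dir)` performs the equivalent lookup.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint_name(snapshot_dir: str = "outputs/snapshot") -> Optional[str]:
    """Return model_checkpoint_path from TensorFlow's 'checkpoint' state file, if present."""
    state = Path(snapshot_dir) / "checkpoint"
    if not state.is_file():
        return None
    # The state file is a text protobuf, e.g.:  model_checkpoint_path: "model-30000"
    match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', state.read_text(), re.MULTILINE)
    return match.group(1) if match else None


def pretrained_checkpoint_ready(snapshot_dir: str = "outputs/snapshot") -> bool:
    """True if the latest checkpoint prefix has a V2 '.index' file (assumed format)."""
    name = latest_checkpoint_name(snapshot_dir)
    if name is None:
        return False
    prefix = Path(name) if Path(name).is_absolute() else Path(snapshot_dir) / name
    return prefix.with_name(prefix.name + ".index").is_file()
```

If `pretrained_checkpoint_ready()` returns `False`, setting resume_training to True is unlikely to pick up the pre-trained weights.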
      

Training loss

The following figure shows the total loss, KL loss and reconstruction loss during training: QuickDraw-shoes pre-training for 30k iterations, followed by QMUL-shoes training for 40k iterations.

[Figure: training loss curves (total, KL and reconstruction loss)]
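Assuming the model follows the usual VAE-style objective, the three curves are related as total = reconstruction + weight × KL, where the KL term pulls the latent code toward a standard normal prior. A minimal NumPy sketch, assuming a diagonal-Gaussian latent parameterized by `mu` and `log_var`, a scalar `kl_weight`, and an optional floor `kl_min` in the style of sketch-rnn's KL tolerance; the exact weighting, annealing and averaging used by this repository are not stated here and may differ, and all names below are hypothetical.

```python
import numpy as np


def gaussian_kl(mu, log_var) -> float:
    """KL(N(mu, exp(log_var)) || N(0, I)), averaged over all latent entries.

    Closed form per dimension: -0.5 * (1 + log_var - mu**2 - exp(log_var)).
    """
    mu = np.asarray(mu, dtype=np.float64)
    log_var = np.asarray(log_var, dtype=np.float64)
    return float(-0.5 * np.mean(1.0 + log_var - mu ** 2 - np.exp(log_var)))


def total_loss(recon_loss: float, mu, log_var,
               kl_weight: float = 1.0, kl_min: float = 0.0) -> float:
    """Hypothetical total loss: reconstruction + kl_weight * max(KL, kl_min)."""
    kl = max(gaussian_kl(mu, log_var), kl_min)
    return recon_loss + kl_weight * kl
```

Below `kl_min` the KL term is held constant, so in an autodiff framework it would contribute no gradient there.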

Sampling

  1. QuickDraw-shoes

    • Make sure data_type is set to QuickDraw in model.py
    • Place the model, checkpoint and config files under the outputs/snapshot/QuickDraw/ directory
    • Run the command:
      python sketch_p2s_sampling.py
      
  2. QMUL-shoes

    • Make sure data_type is set to QMUL in model.py
    • Place the model, checkpoint and config files under the outputs/snapshot/QMUL/ directory
    • Run the command:
      python sketch_p2s_sampling.py
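The steps above place each snapshot in a folder named after the data_type value, so a mismatch between model.py and the snapshot folder is an easy mistake. A minimal pre-flight sketch, assuming QuickDraw and QMUL are the only valid values and that each folder contains TensorFlow's standard `checkpoint` state file; `sampling_snapshot_dir` is a hypothetical helper, and it does not check the config file because its name is not given here.

```python
from pathlib import Path

# The two data_type values used in this README (assumed to be the only valid ones).
VALID_DATA_TYPES = ("QuickDraw", "QMUL")


def sampling_snapshot_dir(data_type: str, root: str = "outputs/snapshot") -> Path:
    """Return <root>/<data_type>/ if it looks ready for sampling, else raise."""
    if data_type not in VALID_DATA_TYPES:
        raise ValueError(f"data_type must be one of {VALID_DATA_TYPES}, got {data_type!r}")
    snap = Path(root) / data_type
    # TensorFlow's Saver writes a plain-text state file literally named 'checkpoint'.
    if not (snap / "checkpoint").is_file():
        raise FileNotFoundError(f"no TensorFlow 'checkpoint' file under {snap}")
    return snap
```

For example, `sampling_snapshot_dir("QMUL")` raises before sampling starts if outputs/snapshot/QMUL/ has not been populated.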
      

All results can be found under the outputs/sampling/ directory.

Credits