This is a reimplementation of the CVPR 2018 paper *Learning to Sketch with Shortcut Cycle Consistency*.
*(Figure: example input photos and the corresponding generated sketches.)*
- Python 3
- TensorFlow (>= 1.4.0)
- Inkscape or CairoSVG (for vector sketch rendering; either one is sufficient): `sudo apt-get install inkscape` or `pip3 install cairosvg`
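Before rendering, it can help to confirm which of the two renderers is actually installed. The sketch below is a hypothetical helper, not part of the repository. It only checks for an `inkscape` executable on `PATH` and for an importable `cairosvg` package, and the preference order (Inkscape first) is an arbitrary choice.

```python
import importlib.util
import shutil
from typing import Optional


def available_renderer() -> Optional[str]:
    """Return "inkscape", "cairosvg", or None, depending on what is installed.

    Checks for the `inkscape` executable on PATH first, then for the
    `cairosvg` Python package. find_spec only locates the package without
    importing it, so a missing native Cairo library would surface only when
    cairosvg is first imported.
    """
    if shutil.which("inkscape") is not None:
        return "inkscape"
    if importlib.util.find_spec("cairosvg") is not None:
        return "cairosvg"
    return None


if __name__ == "__main__":
    renderer = available_renderer()
    if renderer is None:
        print("No renderer found: install Inkscape or run `pip3 install cairosvg`.")
    else:
        print(f"Found renderer: {renderer}")
```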

Following the paper, the model is first pre-trained on the QuickDraw dataset, so both the QuickDraw-shoes and QMUL-shoes data need to be preprocessed as follows:
- QuickDraw-shoes
  - Download the `sketchrnn_shoes.npz` data from QuickDraw.
  - Place the package under the `datasets/QuickDraw/shoes/npz/` directory.
  - Run the command: `python quickdraw_data_processing.py`
- QMUL-shoes
  - Download the photo data from QMUL-Shoe-Chair-V2.
  - Unzip the `ShoeV2_photo` package and place all `.png` files under the `datasets/QMUL/shoes/photos/` directory.
  - Download the preprocessed sketch data from here and place the two `.h5` packages under the `datasets/QMUL/shoes/` directory.
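To sanity-check the downloaded `sketchrnn_shoes.npz` before running `quickdraw_data_processing.py`, a minimal sketch like the one below can summarize it. It assumes the file follows the Sketch-RNN layout (`train`/`valid`/`test` keys, each an object array of per-sketch stroke arrays with rows `(dx, dy, pen_lifted)`). `summarize_npz` is a hypothetical name, and the sketch says nothing about what the preprocessing script itself does.

```python
import numpy as np


def summarize_npz(path):
    """Return {split: {"count": n_sketches, "max_len": longest_sequence}}.

    Assumes a Sketch-RNN-style .npz whose splits are object arrays of
    variable-length stroke arrays; missing splits are skipped.
    """
    summary = {}
    # allow_pickle is required because each split is an object array;
    # encoding="latin1" matches how Sketch-RNN loads these files.
    with np.load(path, encoding="latin1", allow_pickle=True) as data:
        for split in ("train", "valid", "test"):
            if split not in data.files:
                continue
            sketches = data[split]
            lengths = [len(s) for s in sketches]
            summary[split] = {
                "count": len(sketches),
                "max_len": max(lengths) if lengths else 0,
            }
    return summary
```

For example, `summarize_npz("datasets/QuickDraw/shoes/npz/sketchrnn_shoes.npz")` should report non-empty splits if the file was placed correctly.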
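Likewise, the QMUL-shoes layout can be checked with a small standard-library script. `check_qmul_layout` is hypothetical. It counts `.png` files under `datasets/QMUL/shoes/photos/` and lists the `.h5` files directly under `datasets/QMUL/shoes/`, assuming (per the steps above) that exactly two `.h5` packages belong there, without assuming their file names.

```python
from pathlib import Path


def check_qmul_layout(root="datasets/QMUL/shoes"):
    """Report where the QMUL-shoes photos and .h5 sketch packages were found."""
    root = Path(root)
    # glob on a missing directory simply yields nothing, so no existence check is needed.
    photos = sorted((root / "photos").glob("*.png"))
    h5_files = sorted(p.name for p in root.glob("*.h5"))
    return {
        "num_photos": len(photos),
        "h5_files": h5_files,
        # The steps above expect photos plus exactly two .h5 packages.
        "ok": len(photos) > 0 and len(h5_files) == 2,
    }


if __name__ == "__main__":
    print(check_qmul_layout())
```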

Training then proceeds in two stages:

- QuickDraw-shoes pre-training
  - In `model.py`, set `data_type` in `get_default_hparams` to `QuickDraw`.
  - Run the command: `python sketch_p2s_train.py`
- QMUL-shoes training
  - In `model.py`, set `data_type` in `get_default_hparams` to `QMUL`.
  - Make sure the QuickDraw-shoes pre-trained models/checkpoint are placed under the `outputs/snapshot/` directory.
  - In `sketch_p2s_train.py`, set `resume_training` to `True`.
  - Run the command: `python sketch_p2s_train.py`
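Switching between the two stages means editing `data_type` in `get_default_hparams` by hand. The helper below is a hypothetical convenience, not part of the repository. It assumes the value appears in `model.py` as a keyword-style assignment such as `data_type='QMUL'` (as in Sketch-RNN's `HParams(...)` call) and rewrites only the first such occurrence, so check the actual line in `model.py` before relying on it.

```python
import re
from pathlib import Path

VALID_DATA_TYPES = ("QuickDraw", "QMUL")

# Assumed format: data_type='QMUL' or data_type = "QuickDraw".
# Comparisons like `data_type == 'QMUL'` do not match.
_DATA_TYPE_RE = re.compile(r"""(data_type\s*=\s*)(['"])([^'"]*)\2""")


def set_data_type(source: str, value: str) -> str:
    """Return `source` with the first data_type='...' assignment set to `value`."""
    if value not in VALID_DATA_TYPES:
        raise ValueError(f"data_type must be one of {VALID_DATA_TYPES}, got {value!r}")
    new_source, count = _DATA_TYPE_RE.subn(
        # Keep the original spacing and quote character; only the value changes.
        lambda m: f"{m.group(1)}{m.group(2)}{value}{m.group(2)}",
        source,
        count=1,
    )
    if count == 0:
        raise ValueError("no data_type='...' assignment found")
    return new_source


def set_data_type_in_file(path: str, value: str) -> None:
    """Rewrite the data_type value in place in the given file (e.g. model.py)."""
    p = Path(path)
    p.write_text(set_data_type(p.read_text(), value))
```

For the second stage, `set_data_type_in_file("model.py", "QMUL")` would replace the manual edit, under the format assumption above.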
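Because QMUL-shoes training resumes from the pre-trained snapshot, it is worth confirming that `outputs/snapshot/` holds a readable TensorFlow checkpoint before starting. The helper below is a hypothetical, TensorFlow-free stand-in for `tf.train.latest_checkpoint`. It parses the text `checkpoint` state file and returns the recorded `model_checkpoint_path`, resolving relative paths against the snapshot directory, but it does not verify that the referenced checkpoint files exist.

```python
import re
from pathlib import Path
from typing import Optional

# The state file is a text proto with lines like: model_checkpoint_path: "name-1234"
_CKPT_LINE = re.compile(r'^model_checkpoint_path:\s*"(.*)"\s*$')


def latest_checkpoint_prefix(snapshot_dir="outputs/snapshot") -> Optional[str]:
    """Return the checkpoint prefix recorded in <snapshot_dir>/checkpoint, or None."""
    state_file = Path(snapshot_dir) / "checkpoint"
    if not state_file.is_file():
        return None
    for line in state_file.read_text().splitlines():
        match = _CKPT_LINE.match(line.strip())
        if match:
            prefix = match.group(1)
            # Relative paths in the state file are relative to the snapshot directory.
            return prefix if Path(prefix).is_absolute() else str(Path(snapshot_dir) / prefix)
    return None


if __name__ == "__main__":
    print(latest_checkpoint_prefix() or "No checkpoint found under outputs/snapshot/")
```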

The following figure shows the total loss, KL loss, and reconstruction loss during training: QuickDraw-shoes pre-training for 30k iterations, followed by QMUL-shoes training for 40k iterations.
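The three curves can be related through the way Sketch-RNN combines them, assuming this repository keeps Sketch-RNN's formulation (it borrows heavily from that code, but the actual objective here may differ or include additional terms). In that formulation, total = reconstruction + w_KL * max(KL, kl_tolerance), with w_KL annealed from `kl_weight_start` toward `kl_weight`. The defaults below are Sketch-RNN's, not necessarily this repository's, and the sketch uses a single latent vector instead of a batch.

```python
import math
from typing import Sequence


def kl_loss(mu: Sequence[float], presig: Sequence[float], kl_tolerance: float = 0.2) -> float:
    """KL divergence from N(mu, exp(presig)) to N(0, 1), averaged over latent dims.

    `presig` is the log-variance (sigma_hat in Sketch-RNN). Sketch-RNN averages
    over batch and latent dimensions; here one latent vector is averaged over
    its dimensions. The result is floored at `kl_tolerance`.
    """
    terms = [1.0 + s - m * m - math.exp(s) for m, s in zip(mu, presig)]
    kl = -0.5 * sum(terms) / len(terms)
    return max(kl, kl_tolerance)


def kl_weight_at(step: int, kl_weight: float = 0.5, kl_weight_start: float = 0.01,
                 kl_decay_rate: float = 0.99995) -> float:
    """Annealed KL weight: kl_weight_start at step 0, approaching kl_weight as step grows."""
    return kl_weight - (kl_weight - kl_weight_start) * kl_decay_rate ** step


def total_loss(recon_loss: float, mu: Sequence[float], presig: Sequence[float], step: int) -> float:
    """Total loss = reconstruction loss + annealed KL weight * KL loss."""
    return recon_loss + kl_weight_at(step) * kl_loss(mu, presig)
```

The floor explains why a logged KL curve can flatten at `kl_tolerance`, and the annealing explains why the KL term contributes little early in training.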

To sample sketches from the trained models:

- QuickDraw-shoes
  - Make sure `data_type` is set to `QuickDraw` in `model.py`.
  - Place the models/checkpoint/config under the `outputs/snapshot/QuickDraw/` directory.
  - Run the command: `python sketch_p2s_sampling.py`
- QMUL-shoes
  - Make sure `data_type` is set to `QMUL` in `model.py`.
  - Place the models/checkpoint/config under the `outputs/snapshot/QMUL/` directory.
  - Run the command: `python sketch_p2s_sampling.py`
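Before running `sketch_p2s_sampling.py`, a quick check that the snapshot folder matching the chosen `data_type` is populated can save a failed run. `check_sampling_snapshot` is hypothetical and assumes TensorFlow's V2 checkpoint layout (a `checkpoint` state file plus `*.index` files). Since the config file name is not given above, any file with `config` in its name is accepted.

```python
from pathlib import Path
from typing import List

SNAPSHOT_ROOT = Path("outputs/snapshot")


def check_sampling_snapshot(data_type: str, root: Path = SNAPSHOT_ROOT) -> List[str]:
    """Return a list of problems with <root>/<data_type>/ (empty if none were found)."""
    if data_type not in ("QuickDraw", "QMUL"):
        return [f"unexpected data_type {data_type!r}"]
    folder = Path(root) / data_type
    if not folder.is_dir():
        return [f"missing directory {folder}"]
    problems = []
    if not (folder / "checkpoint").is_file():
        problems.append("no 'checkpoint' file")
    # V2 checkpoints write a <prefix>.index file next to the data shards.
    if not any(folder.glob("*.index")):
        problems.append("no '*.index' model files")
    # Config file name is unknown; accept anything containing "config".
    if not any("config" in p.name for p in folder.iterdir() if p.is_file()):
        problems.append("no config file")
    return problems


if __name__ == "__main__":
    for dt in ("QuickDraw", "QMUL"):
        print(dt, check_sampling_snapshot(dt) or "looks complete")
```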

All results can be found under the `outputs/sampling/` directory.
- This code is largely borrowed from the Sketch-RNN and deep_p2s repositories.