
PyTorch reimplementation of Interactive Deep Colorization

Primary language: Python. License: MIT.

Interactive Deep Colorization in PyTorch

This is our PyTorch reimplementation of interactive image colorization, written by Richard Zhang and Jun-Yan Zhu.

This repository contains the training code. The original, official GitHub repository (with an interactive GUI, and originally a Caffe backend) is available separately. That repository has since been updated to support PyTorch models on the backend, and those models can be trained with this repository.

Prerequisites

  • Linux or macOS
  • Python 2 or 3
  • CPU or NVIDIA GPU with CUDA and cuDNN

Getting Started

Installation

  • Clone this repo:
git clone https://github.com/richzhang/colorization-pytorch
cd colorization-pytorch
  • Install the dependencies:
pip install -r requirements.txt

Dataset preparation

  • Download the ILSVRC 2012 dataset and prepare the data with the script below. It creates symlinks into the training set and divides the ILSVRC validation set into validation and test splits for colorization:
python make_ilsvrc_dataset.py --in_path /PATH/TO/ILSVRC12
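As a rough illustration of what "symlinking into validation and test splits" means, here is a minimal Python 3 sketch. The function name `split_and_link`, the 50/50 split, the fixed shuffle seed, and the flat `val/` and `test/` output directories are all hypothetical choices for this example; they are not taken from `make_ilsvrc_dataset.py`, which remains the script to actually run.

```python
import os
import random
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def split_and_link(val_dir, out_dir, test_fraction=0.5, seed=0):
    """Hypothetical sketch: symlink validation images into val/ and test/ splits.

    The split ratio, seed and directory layout are assumptions for illustration.
    Returns the number of images placed in each split.
    """
    images = sorted(
        p for p in Path(val_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    # A seeded shuffle keeps the split reproducible across runs.
    random.Random(seed).shuffle(images)
    n_test = int(len(images) * test_fraction)
    splits = {"test": images[:n_test], "val": images[n_test:]}
    for name, files in splits.items():
        split_dir = Path(out_dir) / name
        split_dir.mkdir(parents=True, exist_ok=True)
        for src in files:
            dst = split_dir / src.name
            # Symlinks avoid duplicating the image files on disk.
            if not dst.exists():
                os.symlink(src.resolve(), dst)
    return {name: len(files) for name, files in splits.items()}
```

Symlinking rather than copying keeps the prepared dataset cheap to create and easy to delete, which is presumably why the real script does the same.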

Training interactive colorization

  • Train a model with bash ./scripts/train_siggraph.sh. This is a two-stage training process. First, the network is trained for automatic colorization using a classification loss. Results are in ./checkpoints/siggraph_class. Then, the network is fine-tuned for interactive colorization using a regression loss. Final results are in ./checkpoints/siggraph_reg2.
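The difference between the two stages can be sketched in plain Python. Stage 1 treats color prediction as classification over quantized ab bins and scores the predicted distribution with cross-entropy; stage 2 scores a directly predicted ab value with a distance. The grid settings `AB_MAX = 110` and `AB_QUANT = 10` and all helper names are assumptions for this example, not values read from the training scripts, and the plain L1 distance stands in for whatever regression loss the actual code uses (it may be a smoothed variant).

```python
import math

# Hypothetical quantization settings for this sketch (not read from the repo).
AB_MAX = 110.0   # ab values are assumed to lie in [-AB_MAX, AB_MAX]
AB_QUANT = 10.0  # spacing between neighboring bin centers
A = int(2 * AB_MAX / AB_QUANT + 1)  # bins per axis (23 with these settings)


def encode_ab(a, b):
    """Map a continuous (a, b) color to the index of its nearest bin on an A x A grid."""
    ia = min(max(round((a + AB_MAX) / AB_QUANT), 0), A - 1)
    ib = min(max(round((b + AB_MAX) / AB_QUANT), 0), A - 1)
    return ia * A + ib


def decode_ab(index):
    """Return the (a, b) center of a bin index."""
    ia, ib = divmod(index, A)
    return ia * AB_QUANT - AB_MAX, ib * AB_QUANT - AB_MAX


def classification_loss(probs, true_ab):
    """Stage 1: cross-entropy of the predicted bin distribution at the ground-truth bin."""
    return -math.log(max(probs[encode_ab(*true_ab)], 1e-12))


def regression_loss(pred_ab, true_ab):
    """Stage 2: L1 distance between a directly predicted ab color and the ground truth."""
    return sum(abs(p - t) for p, t in zip(pred_ab, true_ab))
```

Classification lets the network express several plausible colors at once, which suits automatic colorization; regression produces a single color that can follow a user hint closely, which suits the interactive fine-tuning stage.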

  • To view training results and loss plots, run python -m visdom.server and open http://localhost:8097. The following values are monitored:

    • G_CE is the cross-entropy loss between the predicted color distribution and the ground-truth color.
    • G_entr is the entropy of the predicted distribution.
    • G_entr_hint is the entropy of the predicted distribution at points where a color hint is given.
    • G_L1_max is the L1 distance between the ground-truth color and the argmax of the predicted color distribution.
    • G_L1_mean is the L1 distance between the ground-truth color and the mean of the predicted color distribution.
    • G_L1_reg is the L1 distance between the ground-truth color and the predicted color.
    • G_fake_real is the L1 distance between the predicted color and the ground-truth color (at locations where a hint is given).
    • G_fake_hint is the L1 distance between the predicted color and the input hint color (at locations where a hint is given). It measures how much the network "trusts" the input hint.
    • G_real_hint is the L1 distance between the ground-truth color and the input hint color (at locations where a hint is given).
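To make the monitored quantities concrete, the sketch below computes single-pixel versions of them in plain Python. It is an illustration only: the function names are hypothetical, the ground-truth class is assumed to be the nearest bin center, and L1 is summed over the a and b channels, whereas the actual training code works on whole batches and may average or normalize differently. G_L1_reg corresponds to `l1_ab(pred_ab, true_ab)` at every pixel, and G_entr_hint is the same entropy restricted to hinted pixels.

```python
import math


def l1_ab(p, q):
    """L1 distance between two (a, b) colors, summed over both channels."""
    return abs(p[0] - q[0]) + abs(p[1] - q[1])


def distribution_metrics(probs, centers, true_ab):
    """Single-pixel sketch of G_CE, G_entr, G_L1_max and G_L1_mean.

    probs   : probabilities over color bins (assumed to sum to 1)
    centers : (a, b) center of each bin, same order as probs
    true_ab : ground-truth (a, b) color
    """
    # Assumption: the ground-truth class is the bin whose center is nearest.
    true_bin = min(range(len(centers)), key=lambda k: l1_ab(centers[k], true_ab))
    ce = -math.log(max(probs[true_bin], 1e-12))
    # Entropy of the predicted distribution (G_entr).
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    # Most likely bin vs. expected color under the distribution.
    argmax_ab = centers[max(range(len(probs)), key=probs.__getitem__)]
    mean_ab = (sum(p * c[0] for p, c in zip(probs, centers)),
               sum(p * c[1] for p, c in zip(probs, centers)))
    return {
        "G_CE": ce,
        "G_entr": entropy,
        "G_L1_max": l1_ab(argmax_ab, true_ab),
        "G_L1_mean": l1_ab(mean_ab, true_ab),
    }


def hint_metrics(pred_ab, true_ab, hint_ab):
    """Sketch of the three hint-location metrics for a single hinted pixel."""
    return {
        "G_fake_real": l1_ab(pred_ab, true_ab),
        "G_fake_hint": l1_ab(pred_ab, hint_ab),
        "G_real_hint": l1_ab(true_ab, hint_ab),
    }
```

Comparing G_L1_max with G_L1_mean shows whether the distribution is multimodal: when probability mass is split between distant colors, the mean drifts toward a desaturated average while the argmax stays on one mode.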

Testing interactive colorization

  • Get a model. Either:

    • (1) download the pretrained model by running bash pretrained_models/download_siggraph_model.sh, which will give you a model in ./checkpoints/siggraph_pretrained/latest_net_G.pth. Use siggraph_pretrained as [[NAME]] below.
    • (2) train your own model (as described in the section above), which will leave a model in ./checkpoints/siggraph_reg2/latest_net_G.pth. In this case, use siggraph_reg2 as [[NAME]] below.
  • Test the model on validation data with python test.py --name [[NAME]], where [[NAME]] is siggraph_reg2 or siggraph_pretrained. The test results are saved to an HTML file in ./results/[[NAME]]/latest_val/index.html. For each image in the validation set, the script tests (1) automatic colorization, (2) interactive colorization with a few random hints, and (3) interactive colorization with many random hints.
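The random hints used in these tests are sparse inputs derived from the ground truth. Below is a minimal sketch of one way to draw them, assuming each hint is a square patch filled with the mean ground-truth ab color of that patch and marked in a binary mask. The function name `reveal_random_hints` and this filling rule are illustrative assumptions, not the repository's own sampling code.

```python
import random


def reveal_random_hints(gt_ab, num_hints, patch=1, rng=None):
    """Build a sparse hint map from a ground-truth ab image (hypothetical sketch).

    gt_ab : H x W nested list of (a, b) tuples
    Returns (hint_ab, mask): hint_ab holds the patch-mean ab color inside
    revealed patches and (0, 0) elsewhere; mask is 1 where a hint was given.
    """
    rng = rng or random.Random()
    height, width = len(gt_ab), len(gt_ab[0])
    hint_ab = [[(0.0, 0.0)] * width for _ in range(height)]
    mask = [[0] * width for _ in range(height)]
    for _ in range(num_hints):
        # Pick the top-left corner so the whole patch fits inside the image.
        y = rng.randrange(height - patch + 1)
        x = rng.randrange(width - patch + 1)
        cells = [gt_ab[i][j] for i in range(y, y + patch) for j in range(x, x + patch)]
        mean = (sum(c[0] for c in cells) / len(cells),
                sum(c[1] for c in cells) / len(cells))
        for i in range(y, y + patch):
            for j in range(x, x + patch):
                hint_ab[i][j] = mean
                mask[i][j] = 1
    return hint_ab, mask
```

"A few" versus "many" hints then amounts to calling this with a small or large `num_hints`; patches may overlap, in which case later patches overwrite earlier ones.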

  • Test the model by generating a plot of PSNR vs. the number of hints with python test_sweep.py --name [[NAME]]. This plot was used in Figure 6 of the paper. This test randomly reveals 6x6 color hint patches to the network and measures how accurate the colorization is relative to the ground truth.
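The sweep boils down to computing PSNR against the ground truth for each hint count and averaging over images. The sketch below shows that computation in plain Python. It assumes pixel values scaled to [0, 1] (so `peak=1.0`) and takes a hypothetical `colorize(image, num_hints)` callable in place of the real network; neither the name nor the scaling is taken from `test_sweep.py`.

```python
import math


def psnr(pred, target, peak=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length sequences of values."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(peak ** 2 / mse)


def psnr_sweep(images, colorize, hint_counts, peak=1.0):
    """Mean PSNR over a set of images for each number of revealed hints.

    colorize is a hypothetical callable (image, num_hints) -> prediction with
    the same flattened layout as image.
    """
    results = {}
    for n in hint_counts:
        scores = [psnr(colorize(img, n), img, peak) for img in images]
        results[n] = sum(scores) / len(scores)
    return results
```

Plotting `results` against the hint counts gives a PSNR vs. number-of-hints curve; a model that uses the hints well should show PSNR rising as more patches are revealed.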

  • Test the model interactively with the original official repository. Follow the installation instructions in that repository and run python ideepcolor.py --backend pytorch --color_model [[PTH/TO/MODEL]] --dist_model [[PTH/TO/MODEL]].

Citation

If you use this code for your research, please cite our paper:

@article{zhang2017real,
  title={Real-Time User-Guided Image Colorization with Learned Deep Priors},
  author={Zhang, Richard and Zhu, Jun-Yan and Isola, Phillip and Geng, Xinyang and Lin, Angela S and Yu, Tianhe and Efros, Alexei A},
  journal={ACM Transactions on Graphics (TOG)},
  volume={36},
  number={4},
  year={2017},
  publisher={ACM}
}

Acknowledgments

This code borrows heavily from the pytorch-CycleGAN repository.