
Unofficial PyTorch implementation of Anime Sketch Coloring with Swish-Gated Residual U-Net


Anime Sketch Coloring with Swish-Gated Residual U-Net

Unofficial PyTorch port of SGRUnet (official implementation: here)

Model architecture
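The "Swish" in the model's name refers to the swish function, swish(x) = x · sigmoid(βx). The stdlib sketch below only evaluates that formula for a single scalar as an illustration; it is not code from this repository (whose layers operate on PyTorch tensors), and the default β = 1 is an assumption.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow for large |x|)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # exp of a negative number cannot overflow
    return z / (1.0 + z)


def swish(x: float, beta: float = 1.0) -> float:
    """Swish activation: the input gated by its own sigmoid, x * sigmoid(beta * x)."""
    return x * sigmoid(beta * x)
```

For large positive inputs swish behaves like the identity, for large negative inputs it goes to 0, and near 0 it is smooth, which is what distinguishes it from a hard ReLU cutoff.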


Results after 13 epochs of training, using the settings in this config.py. google drive
Training this model takes a long time, so I trained it for only 13 epochs; these results do not reflect the model's best possible performance. example1 example2 log


  • LayerNorm requires a lot of memory, so a BatchNorm option is also implemented. BatchNorm greatly speeds up training but may affect output quality. Choose between the two in config.py.
  • To save memory, you can upsample with bilinear interpolation instead of transposed convolution (as in the paper).
  • Two datasets are supported: Anime Sketch Colorization Pair and the safebooru dataset used in the paper. Select one in config.py.
  • The network used to compute the loss can come from either the ResNet family or the VGG family. Also selectable in config.py.
  • Mini-batch training is supported.
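The normalization and upsampling choices listed above can be sketched as two small PyTorch factory functions. The names make_norm and make_upsample and the string options "batch"/"layer" and "transpose"/"bilinear" are hypothetical, not the actual keys in config.py. For the LayerNorm case the sketch uses nn.GroupNorm with a single group, which normalizes each sample over (C, H, W) like LayerNorm but learns per-channel rather than per-element affine parameters. The 1x1 convolution after bilinear interpolation is also an assumption, added only so both upsampling options can change the channel count.

```python
import torch
import torch.nn as nn


def make_norm(kind: str, channels: int) -> nn.Module:
    """Normalization for a conv feature map of shape (N, C, H, W)."""
    if kind == "batch":
        # Per-channel statistics over (N, H, W).
        return nn.BatchNorm2d(channels)
    if kind == "layer":
        # One group: per-sample statistics over (C, H, W), i.e. LayerNorm-style,
        # without having to know H and W when the layer is built.
        return nn.GroupNorm(1, channels)
    raise ValueError(f"unknown norm: {kind!r}")


def make_upsample(kind: str, in_ch: int, out_ch: int) -> nn.Module:
    """Module that doubles H and W and maps in_ch -> out_ch channels."""
    if kind == "transpose":
        # Learned 2x2 transposed convolution with stride 2.
        return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
    if kind == "bilinear":
        # Parameter-free interpolation, then a 1x1 conv to change channels.
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(in_ch, out_ch, kernel_size=1),
        )
    raise ValueError(f"unknown upsample: {kind!r}")
```

With the same channel counts, the bilinear path has roughly a quarter of the weights of a 2x2 transposed convolution (in_ch * out_ch versus 4 * in_ch * out_ch, plus biases).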

Data folder

Anime Sketch Colorization Pair


└── data/
    ├── train/
    └── val/
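In the Anime Sketch Colorization Pair layout there are no separate img/ and label/ folders, which suggests each file holds the colored image and its sketch side by side. The NumPy sketch below assumes the colored image is on the left half and the sketch on the right; that ordering and the helper name split_pair are assumptions, so check a sample image before relying on them.

```python
import numpy as np


def split_pair(pair: np.ndarray):
    """Split a side-by-side (H, W, C) image into (left, right) halves.

    Assumption: left half = colored image, right half = sketch.
    """
    width = pair.shape[1]
    if width % 2:
        raise ValueError(f"expected an even width, got {width}")
    half = width // 2
    color = pair[:, :half]   # left half
    sketch = pair[:, half:]  # right half
    return color, sketch
```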

safebooru (paper)

└── data/
    ├── train/
    │    ├── img/
    │    └── label/
    └── val/
         ├── img/
         └── label/
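For the safebooru layout, one plausible way to build training pairs is to match files that share a name in img/ and label/. The helper list_pairs below is a hypothetical stdlib sketch of that matching; the repository's own dataset code may pair files differently, and which of the two folders holds the sketches is not stated here.

```python
from pathlib import Path


def list_pairs(split_dir):
    """Return (img_path, label_path) tuples for files present in both folders.

    split_dir is e.g. data/train or data/val; files are matched by name.
    """
    split_dir = Path(split_dir)
    img_dir, label_dir = split_dir / "img", split_dir / "label"
    pairs = []
    for img_path in sorted(img_dir.iterdir()):
        if not img_path.is_file():
            continue
        label_path = label_dir / img_path.name
        if label_path.is_file():  # skip images that have no counterpart
            pairs.append((img_path, label_path))
    return pairs
```

Unmatched files are skipped silently here; raising an error instead would surface an incomplete dataset earlier.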


PyTorch >= 1.1.0

Install the project's dependencies from the requirements.txt file:

$ pip install -r requirements.txt
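To check that an installed PyTorch meets the PyTorch >= 1.1.0 requirement, version strings should be compared as numeric tuples rather than as raw strings, since a string comparison ranks "1.10.0" below "1.9.0". The helpers below are a hypothetical stdlib sketch; local build tags such as "+cu117" and pre-release suffixes such as "a0" are simply dropped.

```python
import re


def parse_version(version: str) -> tuple:
    """'1.13.1+cu117' -> (1, 13, 1); missing parts are padded with 0."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)  # leading digits only, e.g. '0a0' -> 0
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(version: str, minimum=(1, 1, 0)) -> bool:
    """True if version is at least minimum, compared numerically."""
    return parse_version(version) >= minimum
```

Usage: after `import torch`, call `meets_minimum(torch.__version__)`.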


Modify config.py as needed, then run:


python main.py


Modify the following variables in inference.py as needed:

  • model_path
  • file_name
  • file_path
  • output_path
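As a purely hypothetical illustration, the four variables might be set as below. All paths and file names are placeholders, and reading file_path as the folder containing the input sketch and file_name as the sketch inside it is an assumption, not documented behavior of inference.py.

```python
import os

# Placeholder values; replace them with your own locations.
model_path = os.path.join("checkpoints", "model.pth")  # trained weights to load
file_path = "samples"                                  # assumed: folder holding the input sketch
file_name = "sketch.png"                               # assumed: sketch file inside file_path
output_path = "results"                                # assumed: folder for the colorized output

# Full path of the sketch to colorize under the assumption above.
input_file = os.path.join(file_path, file_name)
```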


python inference.py