SGRUnet-pytorch

Unofficial PyTorch implementation of Anime Sketch Coloring with Swish-Gated Residual U-Net

Anime Sketch Coloring with Swish-Gated Residual U-Net

Unofficial PyTorch port of SGRUnet (official implementation: here).

[Figure: model architecture]
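The "swish" in the model's name refers to the swish activation, f(x) = x·σ(βx), where σ is the logistic sigmoid. The plain-Python sketch below shows only that activation function; how the paper combines it into swish-gated residual blocks is defined in the paper and in this repository's model code, and is not reproduced here.

```python
import math


def sigmoid(x: float) -> float:
    # Logistic function, split by sign so math.exp never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def swish(x: float, beta: float = 1.0) -> float:
    # Swish: x * sigmoid(beta * x). With beta = 1 this is the function
    # PyTorch also provides as torch.nn.SiLU (in newer releases).
    return x * sigmoid(beta * x)


if __name__ == "__main__":
    for v in (-4.0, -1.0, 0.0, 1.0, 4.0):
        print(f"swish({v:+.1f}) = {swish(v):+.4f}")
```

Unlike ReLU, swish is smooth and lets small negative values through, which is one reason it is used as a gate rather than a hard on/off switch.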

Performance

Results after 13 epochs of training, using the configuration in this repository's config.py: google drive
Training this model is slow, so I trained for only 13 epochs; these results therefore do not reflect the model's best achievable performance. Samples and training log: example1, example2, log

Features

  • LayerNorm requires a lot of memory, so BatchNorm is also implemented. BatchNorm greatly speeds up training but may affect the model's performance. Choose between them in config.py.
  • To save memory, you can upsample with bilinear interpolation instead of transposed convolution (as used in the paper).
  • Two datasets are supported: Anime Sketch Colorization Pair and the dataset used in the paper. Select one in config.py.
  • The network used to compute the loss can come from either the ResNet family or the VGG family. Select one in config.py.
  • Mini-batch training is supported.
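To put rough numbers on the normalization and upsampling options in the feature list, the sketch below counts learnable parameters for the PyTorch layers involved. It assumes LayerNorm normalizes over the full [C, H, W] feature map with a per-element affine transform and that the transposed convolution doubles the spatial size; both are assumptions, not details read from this repository's model code. Parameter count is also only a rough proxy: activation memory during training is usually the larger cost.

```python
"""Rough parameter counts for the normalization and upsampling choices.

Assumptions (not taken from this repository): LayerNorm normalizes over
[C, H, W] with elementwise_affine=True, and the transposed convolution
halves the channel count while doubling H and W.
"""


def batchnorm2d_params(channels: int) -> int:
    # torch.nn.BatchNorm2d(C): learnable weight and bias of shape [C].
    # running_mean / running_var are buffers and are not counted here.
    return 2 * channels


def layernorm_params(channels: int, height: int, width: int) -> int:
    # torch.nn.LayerNorm([C, H, W]) with elementwise_affine=True:
    # weight and bias each have shape [C, H, W].
    return 2 * channels * height * width


def bilinear_upsample_params() -> int:
    # torch.nn.Upsample(scale_factor=2, mode="bilinear") has no learnable parameters.
    return 0


def conv_transpose2d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    # torch.nn.ConvTranspose2d(in_ch, out_ch, kernel) with groups=1:
    # weight shape [in_ch, out_ch, kernel, kernel], plus out_ch bias terms.
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


def conv_transpose2d_out(size: int, kernel: int, stride: int, padding: int = 0,
                         output_padding: int = 0, dilation: int = 1) -> int:
    # Output-size formula from the PyTorch ConvTranspose2d documentation.
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


if __name__ == "__main__":
    c, h, w = 256, 64, 64  # hypothetical feature-map size
    print("BatchNorm2d params:", batchnorm2d_params(c))
    print("LayerNorm([C, H, W]) params:", layernorm_params(c, h, w))
    print("bilinear upsample params:", bilinear_upsample_params())
    print("ConvTranspose2d params:", conv_transpose2d_params(c, c // 2, kernel=2))
    print("upsampled size:", conv_transpose2d_out(h, kernel=2, stride=2))
```

For this hypothetical 256×64×64 feature map, BatchNorm2d has 512 learnable parameters versus 2,097,152 for a per-element LayerNorm, and a 2×2 stride-2 transposed convolution from 256 to 128 channels adds 131,200 parameters where bilinear upsampling adds none.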

Data folder

colorgram

anime_colorization
└── data/
    ├── train/
    |   
    └── val/

safebooru (paper)

anime_colorization
└── data/
    ├── train/
    |    ├── img/
    |    └── label/
    └── val/
         ├── img/
         └── label/
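For the safebooru layout, a quick way to catch an incomplete download before training is to check that every file in img/ has a counterpart in label/. The sketch below assumes that paired files share the same filename stem and are .png or .jpg images; both are assumptions about the data, not something this README states. The helper names are hypothetical.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def pair_by_stem(image_files, label_files):
    """Match files whose names share a stem.

    Returns (pairs, images_without_label, labels_without_image).
    """
    images = {Path(f).stem: f for f in image_files}
    labels = {Path(f).stem: f for f in label_files}
    common = sorted(images.keys() & labels.keys())
    pairs = [(images[s], labels[s]) for s in common]
    return pairs, sorted(images.keys() - labels.keys()), sorted(labels.keys() - images.keys())


def list_images(folder: Path):
    # File names (not full paths) of the image files directly inside folder.
    return sorted(p.name for p in folder.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)


def check_split(split_dir: Path):
    return pair_by_stem(list_images(split_dir / "img"), list_images(split_dir / "label"))


if __name__ == "__main__":
    root = Path("anime_colorization/data")  # layout from the safebooru tree above
    for split in ("train", "val"):
        split_dir = root / split
        if not (split_dir / "img").is_dir() or not (split_dir / "label").is_dir():
            print(f"{split}: img/ or label/ not found under {split_dir}")
            continue
        pairs, no_label, no_image = check_split(split_dir)
        print(f"{split}: {len(pairs)} pairs, {len(no_label)} images without a label, "
              f"{len(no_image)} labels without an image")
```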

Setup

PyTorch >= 1.1.0

Install the necessary dependencies for this project with the requirements.txt file:

$ pip install -r requirements.txt
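To confirm that the installed PyTorch meets the >= 1.1.0 requirement, a small stdlib-only check like the following compares the leading major.minor.patch numbers of torch.__version__. It is a sketch: local build tags such as +cu117 are ignored, and pre-release suffixes such as a0 are not treated as older than the final release.

```python
import re


def parse_version(version: str) -> tuple:
    # Extract (major, minor, patch) from strings like "1.1.0", "1.13.1+cu117" or "2.0".
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def meets_minimum(version: str, minimum: str = "1.1.0") -> bool:
    # Tuple comparison orders versions numerically: (1, 13, 1) > (1, 1, 0).
    return parse_version(version) >= parse_version(minimum)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        status = "OK" if meets_minimum(torch.__version__) else "older than 1.1.0"
        print(torch.__version__, status)
```

If the packaging library is available, packaging.version.Version handles pre-release and local tags more rigorously than this regex.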

Config

Modify config.py as needed.
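config.py holds the switches described in the feature list: normalization, upsampling, dataset, and loss network. This README does not show its actual variable names, so the following is a hypothetical sketch of how those choices could be grouped and validated so that a typo fails immediately instead of partway through a long training run. Every field name and option string here is an assumption; only the 13-epoch default comes from this README.

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # Hypothetical names and values; check config.py for the real ones.
    norm: str = "batch"          # "batch" (faster, less memory) or "layer"
    upsample: str = "bilinear"   # "bilinear" (less memory) or "transpose" (as in the paper)
    dataset: str = "colorgram"   # "colorgram" or "safebooru" (the paper's dataset)
    loss_net: str = "vgg"        # VGG-family or ResNet-family loss network
    batch_size: int = 4          # hypothetical default
    epochs: int = 13             # the epoch count used for the results above

    def __post_init__(self):
        # Reject unknown option strings as soon as the config is built.
        allowed = {
            "norm": {"batch", "layer"},
            "upsample": {"bilinear", "transpose"},
            "dataset": {"colorgram", "safebooru"},
            "loss_net": {"vgg", "resnet"},
        }
        for name, choices in allowed.items():
            value = getattr(self, name)
            if value not in choices:
                raise ValueError(f"{name}={value!r}; expected one of {sorted(choices)}")
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")


if __name__ == "__main__":
    print(TrainConfig(norm="layer", upsample="transpose", dataset="safebooru"))
```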

Train

$ python main.py

Inference

Set the following variables in inference.py as needed:

  • model_path
  • file_name
  • file_path
  • output_path

Then run:

$ python inference.py