This repository is our PyTorch implementation of Shift-Net. It is intended for those who are interested in our work and want a skeleton PyTorch implementation. The original code is at https://github.com/Zhaoyi-Yan/Shift-Net. I will upload the PyTorch models in the coming months (sorry for the delay).
- Linux or OSX.
- Python 2 or Python 3.
- CPU or NVIDIA GPU + CUDA CuDNN.
- Tested with PyTorch 0.3.
- Install PyTorch and dependencies from http://pytorch.org/
- Install the Python libraries visdom and dominate:
```bash
pip install visdom
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/Zhaoyi-Yan/Shift-Net_pytorch
cd Shift-Net_pytorch
```
- Download your own inpainting datasets.
- Train a model:
```bash
python train.py
```
- To view training results and loss plots, run
```bash
python -m visdom.server
```
and click the URL http://localhost:8097.
- Test the model:
```bash
python test.py
```
The test results will be saved to an HTML file in ./results/.
If you find this work useful, please cite:
```
@InProceedings{Yan_2018_Shift,
  author = {Yan, Zhaoyi and Li, Xiaoming and Li, Mu and Zuo, Wangmeng and Shan, Shiguang},
  title = {Shift-Net: Image Inpainting via Deep Feature Rearrangement},
  booktitle = {The European Conference on Computer Vision (ECCV)},
  month = {September},
  year = {2018}
}
```
In this version of the code, InnerCos does not support single-GPU training. This is strange, and I have no idea how to solve it yet. If you train our model on a single GPU, please set skip=1 in options/base_options. Otherwise, if you want InnerCos to work during single-GPU training, please refer to the code before commit aa2382b.
If you find the code a little hard to read, you may refer to the Guides.
- Make U-Net handle inputs of arbitrary sizes (by resizing decoder features to match the spatial size of the corresponding encoder features).
- Update the code for pytorch >= 0.4.
- Clean the code and delete useless comments.
- Many more operations that help performance.
- A guide to our code; we hope it helps you understand the code more easily.
- Directly resize the mask to save computation.
- The guidance loss seems to be defined on the global region; make it work only in the masked region.
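The U-Net item above (resizing decoder features to fit the corresponding encoder features) can be sketched as follows. This is a minimal illustration, not the repo's actual code; the helper name `match_and_concat` is hypothetical, and `F.interpolate` assumes PyTorch >= 0.4:

```python
import torch
import torch.nn.functional as F

def match_and_concat(dec_feat, enc_feat):
    """Hypothetical skip-connection helper: resize decoder features to the
    spatial size of the corresponding encoder features, then concatenate
    along the channel dimension, so U-Net can handle arbitrary input sizes."""
    if dec_feat.shape[2:] != enc_feat.shape[2:]:
        dec_feat = F.interpolate(dec_feat, size=enc_feat.shape[2:],
                                 mode='bilinear', align_corners=False)
    return torch.cat([enc_feat, dec_feat], dim=1)

# Example: odd-sized encoder features vs. even-sized decoder features.
enc = torch.randn(1, 64, 57, 57)
dec = torch.randn(1, 64, 56, 56)
out = match_and_concat(dec, enc)
print(out.shape)  # torch.Size([1, 128, 57, 57])
```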
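For the last item, restricting the guidance loss to the masked region could look like the sketch below. This is only an assumption about how the fix might be written, not the repo's implementation; the function name `masked_guidance_loss` is hypothetical, and it also resizes the mask directly to the feature resolution, as the TODO above suggests, to save computation:

```python
import torch
import torch.nn.functional as F

def masked_guidance_loss(feat, target_feat, mask):
    """Hypothetical L2 guidance loss computed only inside the masked region.

    feat, target_feat: feature maps of shape (N, C, H, W).
    mask: binary tensor of shape (N, 1, Hm, Wm); 1 marks the missing region.
    """
    # Resize the mask directly to the feature resolution (nearest keeps it binary).
    mask = F.interpolate(mask, size=feat.shape[2:], mode='nearest')
    # Squared error, zeroed outside the masked region via broadcasting.
    diff = (feat - target_feat) ** 2 * mask
    # Normalize by the number of masked feature elements.
    return diff.sum() / (mask.sum() * feat.shape[1] + 1e-8)
```

With a constant difference of 1 inside the mask, the loss is 1 regardless of how large the unmasked region is, which is the point of masking.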
We benefit a lot from pytorch-CycleGAN-and-pix2pix.