We assume you have an NVIDIA GPU and have installed CUDA. The code does not currently support running on CPU.
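To confirm that the GPU driver and CUDA toolkit are visible from your shell, you can look for the nvidia-smi and nvcc commands. The snippet below is a minimal sketch and is not part of Shift-Net. The function name missing_tools is hypothetical. It only checks that each command is on PATH (nvcc is found only if the CUDA bin folder has been added to PATH), and it does not verify that CUDA works with Torch.

```shell
#!/usr/bin/env bash
# Hypothetical helper, not part of Shift-Net: print each command name from
# the argument list that cannot be found on PATH, one per line.
missing_tools() {
  for tool in "$@"; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      printf '%s\n' "$tool"
    fi
  done
}

# nvidia-smi ships with the NVIDIA driver; nvcc ships with the CUDA toolkit.
missing=$(missing_tools nvidia-smi nvcc)
if [ -n "$missing" ]; then
  echo "Not found on PATH:"
  printf '%s\n' "$missing"
else
  echo "nvidia-smi and nvcc found."
fi
```

An empty result only means the commands exist; running nvidia-smi itself is a better check that the driver can see the GPU.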
- Install Torch and its dependencies from https://github.com/torch/distro
- Install the Torch packages nngraph, cudnn and display:
luarocks install nngraph
luarocks install cudnn
luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec
- Clone this repo:
git clone https://github.com/Zhaoyi-Yan/Shift-Net
cd Shift-Net
- Download the pretrained model:
bash scripts/download_models.sh
The model will be downloaded and unzipped.
- Download your own dataset.
- Change the options in train.lua to match the path of your dataset. Normally, you need to specify at least three options: DATA_ROOT, phase and name.
For example:
DATA_ROOT: ./datasets/Paris_StreetView_Dataset/
phase: paris_train
name: paris_train_shiftNet
This means that the training images are in the folder ./datasets/Paris_StreetView_Dataset/paris_train/. The name option gives your experiment a name, e.g., paris_train_shiftNet. During training, the checkpoints are stored in the folder ./checkpoints/paris_train_shiftNet/.
- Train a model:
th train.lua
- Display intermediate results in the browser. Set the option display = 1, then open another console and run:
th -ldisplay.start
- Open this URL in your browser: http://localhost:8000
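Before running th train.lua, it can help to check which folders the three training options point to, following the rule above: images are read from DATA_ROOT/phase/ and checkpoints are written to ./checkpoints/name/. The snippet below is a sketch using the example values from this README; the helper name join_dir is hypothetical and is not part of Shift-Net, and it does not read train.lua itself.

```shell
#!/usr/bin/env bash
# Hypothetical helper: join a base folder and a subfolder, dropping a
# trailing slash on the base so the result has no doubled slash.
join_dir() {
  printf '%s/%s/\n' "${1%/}" "$2"
}

# Example values from this README; edit them to match train.lua.
DATA_ROOT=./datasets/Paris_StreetView_Dataset/
phase=paris_train
name=paris_train_shiftNet

train_dir=$(join_dir "$DATA_ROOT" "$phase")   # where training images are read from
ckpt_dir=$(join_dir ./checkpoints "$name")     # where checkpoints are written

echo "training images: $train_dir"
echo "checkpoints:     $ckpt_dir"

# Warn early if the image folder is missing or empty.
if [ -z "$(ls -A "$train_dir" 2>/dev/null)" ]; then
  echo "warning: no files found in $train_dir" >&2
fi
```

If the warning appears, DATA_ROOT or phase most likely does not match where you put your dataset.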
Before testing, change the options DATA_ROOT, phase, name, checkpoint_dir and which_epoch.
For example, to test your trained model at epoch 30:
DATA_ROOT: ./datasets/Paris_StreetView_Dataset/
phase: paris_train
name: paris_train_shiftNet
checkpoint_dir: ./checkpoints/
which_epoch: '30'
The first two options determine where the dataset is; checkpoint_dir and name define the folder where the model is stored, and which_epoch selects the checkpoint saved at that epoch.
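A missing or misspelled model folder is an easy mistake to make here, so a quick check before running th test.lua may save a failed run. The snippet below is a sketch that only verifies that the folder checkpoint_dir/name exists. The helper name model_dir_ready is hypothetical and not part of Shift-Net, and it deliberately does not look for a specific which_epoch file, because the checkpoint file naming is not described in this README.

```shell
#!/usr/bin/env bash
# Hypothetical helper: succeed if the folder checkpoint_dir/name exists,
# otherwise print a hint to stderr and return a non-zero status.
model_dir_ready() {
  local dir="${1%/}/$2"
  if [ -d "$dir" ]; then
    echo "found model folder: $dir"
    return 0
  fi
  echo "no model folder at $dir; check checkpoint_dir and name" >&2
  return 1
}

# Example values from this README; edit them to match your test options.
checkpoint_dir=./checkpoints/
name=paris_train_shiftNet
which_epoch=30

if model_dir_ready "$checkpoint_dir" "$name"; then
  echo "ready to test epoch $which_epoch with: th test.lua"
fi
```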
- Finally, test the model:
th test.lua
We have benefited a lot from pix2pix and DCGAN. The data loader is modified from pix2pix, and the implementation of Instance Normalization is borrowed from the Instance Normalization project. The shift operation is inspired by style-swap.