Chen-Hsuan Lin and Simon Lucey
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017 (oral presentation)
Paper: https://www.andrew.cmu.edu/user/chenhsul/paper/CVPR2017.pdf
arXiv preprint: https://arxiv.org/abs/1612.03897
We provide TensorFlow code for the following experiments:
- MNIST classification
- traffic sign classification
[NEW!] The PyTorch implementation of the MNIST experiment is now up!
This code is developed with Python3 (python3), but it is also compatible with Python2.7 (python). TensorFlow r1.0+ is required. The dependencies can be installed by running
pip3 install --upgrade numpy scipy termcolor matplotlib tensorflow-gpu
If you're using Python2.7, use pip2 instead; if you don't have sudo access, add the --user flag.
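To confirm that the installed TensorFlow meets the r1.0+ requirement, a small check along these lines may help. This is only a sketch: version_tuple and meets_minimum are hypothetical helper names, not part of this repository, and the check covers the lower bound only.

```python
def version_tuple(version):
    """Turn a version string such as '1.4.0' or '1.0.0-rc1' into a tuple of leading integers."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version, minimum=(1, 0)):
    """Return True if the version is at least the given minimum (lower bound only)."""
    return version_tuple(version) >= minimum


try:
    import tensorflow as tf
    status = "OK" if meets_minimum(tf.__version__) else "older than r1.0"
    print("TensorFlow", tf.__version__, status)
except ImportError:
    print("TensorFlow is not installed")
```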
The training code can be executed via the command
python3 train.py <netType> [(options)]
<netType> should be one of the following:
- CNN: standard convolutional neural network
- STN: Spatial Transformer Network (STN)
- IC-STN: Inverse Compositional Spatial Transformer Network (IC-STN)
The list of optional arguments can be found by executing python3 train.py --help.
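To compare the three network types, the training command above can be assembled programmatically. The following is a minimal sketch, not part of the released code: build_train_command and train_all are hypothetical helpers, and any extra options passed through should be taken from the output of --help rather than assumed.

```python
import subprocess

# The three <netType> values accepted by train.py.
NET_TYPES = ("CNN", "STN", "IC-STN")


def build_train_command(net_type, extra_args=(), python="python3"):
    """Return the argument list for one run of train.py."""
    if net_type not in NET_TYPES:
        raise ValueError("unknown netType: %r" % (net_type,))
    return [python, "train.py", net_type] + list(extra_args)


def train_all(extra_args=()):
    """Hypothetical helper: train each network type in turn, stopping on the first failure."""
    for net_type in NET_TYPES:
        subprocess.check_call(build_train_command(net_type, extra_args))


# Dry run: only print the commands; call train_all() to actually launch them.
for net_type in NET_TYPES:
    print(" ".join(build_train_command(net_type)))
```

The runs are sequential here so that they do not compete for the same GPU; use python="python" with build_train_command for a Python2.7 setup.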
The default training settings in this released code are slightly different from those in the paper; they are more stable and optimize the networks better.
When the code is run for the first time, the datasets will be automatically downloaded and preprocessed.
The checkpoints are saved in the automatically created directory model_GROUP; summaries are saved in summary_GROUP.
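After a few runs, several model_GROUP and summary_GROUP directories may accumulate. The sketch below lists them by group name; find_run_dirs is a hypothetical helper, and it assumes only the directory naming described above, nothing about the files inside.

```python
import glob
import os


def find_run_dirs(root="."):
    """Map each GROUP name to its model_/summary_ directories found under root."""
    runs = {}
    for prefix in ("model_", "summary_"):
        for path in sorted(glob.glob(os.path.join(root, prefix + "*"))):
            if os.path.isdir(path):
                group = os.path.basename(path)[len(prefix):]
                runs.setdefault(group, {})[prefix.rstrip("_")] = path
    return runs


# Print one line per group found in the current directory.
for group, dirs in sorted(find_run_dirs().items()):
    print(group, dirs)
```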
We've included code to visualize the training over TensorBoard. To execute, run
tensorboard --logdir=summary_GROUP --port=6006
We provide three types of data visualization:
- SCALARS: training/test error over iterations
- IMAGES: alignment results and mean/variance appearances
- GRAPH: network architecture
The PyTorch version of the code is still under active development. The training speed is currently slower than the TensorFlow version. Suggestions on improvements are welcome! :)
This code is developed with Python3 (python3). It has not been tested with Python2.7 yet. PyTorch 0.2.0+ is required. Please see http://pytorch.org/ for installation instructions.
Visdom is also required; it can be installed by running
pip3 install --upgrade visdom
If you don't have sudo access, add the --user flag.
First, start a Visdom server by running
python3 -m visdom.server -port=7000
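Before launching training, it can be useful to confirm that the Visdom server is actually listening on the chosen port. The following sketch does this with a plain TCP connection attempt; port_is_open is a hypothetical helper, and the host localhost is an assumption that the server runs on the same machine.

```python
import socket


def port_is_open(host="localhost", port=7000, timeout=1.0):
    """Return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unresolved hosts.
        return False


if port_is_open():
    print("Visdom server reachable on port 7000")
else:
    print("Nothing is listening on port 7000; start the Visdom server first")
```

This only checks that the port accepts connections, not that the process behind it is Visdom.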
The training code can be executed via the command (using the same port number)
python3 train.py <netType> --port=7000 [(options)]
<netType> should be one of the following:
- CNN: standard convolutional neural network
- STN: Spatial Transformer Network (STN)
- IC-STN: Inverse Compositional Spatial Transformer Network (IC-STN)
The list of optional arguments can be found by executing python3 train.py --help.
The default training settings in this released code are slightly different from those in the paper; they are more stable and optimize the networks better.
When the code is run for the first time, the datasets will be automatically downloaded and preprocessed.
The checkpoints are saved in the automatically created directory model_GROUP; summaries are saved in summary_GROUP.
We provide the following data visualizations on Visdom:
- Training/test error over iterations
- Alignment results and mean/variance appearances
If you find our code useful for your research, please cite
@inproceedings{lin2017inverse,
title={Inverse Compositional Spatial Transformer Networks},
author={Lin, Chen-Hsuan and Lucey, Simon},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition ({CVPR})},
year={2017}
}
Please contact me (chlin@cmu.edu) if you have any questions!