Replication of
Tomas Jakab*, Ankush Gupta*, Hakan Bilen, Andrea Vedaldi (* equal contribution). Unsupervised Learning of Object Landmarks through Conditional Image Generation. Advances in Neural Information Processing Systems (NeurIPS) 2018.
and a partial replication of
Tejas Kulkarni, Ankush Gupta, Catalin Ionescu, Sebastian Borgeaud, Malcolm Reynolds, Andrew Zisserman, Volodymyr Mnih. Unsupervised Learning of Object Keypoints for Perception and Control. NeurIPS 2019.
in PyTorch by Duane
conditional image generation (Tomas and Ankush et al)
transporter network (Ankush et al)
requires Python 3.6 and NVIDIA apex
Windows is not supported
git clone https://github.com/duanenielsen/keypoints
python3 -m venv ~/.venv/keypoints
. ~/.venv/keypoints/bin/activate
cd keypoints
pip3 install .
NVIDIA apex is required to run; open an issue if you would like me to make it optional
follow the apex readme to install it
learning keypoints on faces requires the CelebA dataset, download it from https://drive.google.com/open?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM
extract it so that the directory layout matches the tree below
keypoints/data
├── celeba-low
│   └── img_align_celeba
│       ├── 000001.jpg
│       ├── 000002.jpg
│       ├── 000003.jpg
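a quick way to confirm the images were extracted to the layout above is to count the .jpg files in the expected directory. this is a minimal sketch and not part of the repository: the function name count_celeba_images is hypothetical, and it assumes you run it from the keypoints directory so that data resolves to keypoints/data

```python
from pathlib import Path


def count_celeba_images(data_root="data"):
    """Count .jpg files under <data_root>/celeba-low/img_align_celeba.

    Hypothetical helper, not part of the repository. The directory names
    come from the tree in this readme; data_root is an assumption about
    where you run the script from.
    """
    image_dir = Path(data_root) / "celeba-low" / "img_align_celeba"
    if not image_dir.is_dir():
        raise FileNotFoundError(f"expected directory not found: {image_dir}")
    # Only direct children are counted; the archive is expected to be flat.
    return sum(1 for _ in image_dir.glob("*.jpg"))


if __name__ == "__main__":
    try:
        print(count_celeba_images("data"), "images found")
    except FileNotFoundError as err:
        print(err)
```

if the directory is missing, the message shows the path that was expected, which is usually enough to see whether the archive was extracted one level too high or too low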
pong example with 16 bit precision
python3 transporter.py --run_id 2 --config configs/transporter_pong_grey.yaml
if you don't have an RTX card, or don't want to use mixed precision, you can disable it, but you may need to adjust the minibatch size; use the flags
--opt_level O0 --batch_size 16
if you get GPU memory errors, reduce batch size until it fits on your card
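one simple way to reduce the batch size is to halve it after each out-of-memory failure. the sketch below only prints the commands to try, in order; it does not run training or detect memory errors. the helper names are hypothetical, halving is an assumption rather than a recommendation from the repository, and the flags are the ones shown in this readme

```python
import shlex


def batch_size_candidates(start, minimum=1):
    """Return start, start // 2, start // 4, ... without going below minimum."""
    minimum = max(1, minimum)  # guard against an endless loop at size 0
    sizes = []
    size = start
    while size >= minimum:
        sizes.append(size)
        size //= 2
    return sizes


def transporter_command(batch_size, run_id=2,
                        config="configs/transporter_pong_grey.yaml"):
    """Build the 32 bit pong command from this readme with a given batch size."""
    args = [
        "python3", "transporter.py",
        "--run_id", str(run_id),
        "--config", config,
        "--opt_level", "O0",
        "--batch_size", str(batch_size),
    ]
    return " ".join(shlex.quote(a) for a in args)


if __name__ == "__main__":
    # Try these top to bottom, stopping at the first one that fits in memory.
    for size in batch_size_candidates(16):
        print(transporter_command(size))
```

smaller batches change the training dynamics, so results at a reduced batch size may not match runs at the configured size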
tensorboard files and checkpoints are saved to data/models
2 checkpoints are saved during runs
checkpoint - the latest version of the model during training
best - the model that achieved the best test loss during training
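the latest/best scheme can be illustrated with a short sketch. this is not the repository's code: CheckpointKeeper is a hypothetical class, pickle stands in for whatever serializer the training scripts actually use, and writing each checkpoint as a single file is an assumption

```python
import math
import pickle
from pathlib import Path


class CheckpointKeeper:
    """Hypothetical helper that keeps a latest and a best snapshot.

    'checkpoint' is overwritten on every save; 'best' is overwritten only
    when the test loss is lower than any seen so far.
    """

    def __init__(self, run_dir):
        self.run_dir = Path(run_dir)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.best_loss = math.inf

    def save(self, state, test_loss):
        """Save state; return True if it also became the new best."""
        self._write("checkpoint", state)
        improved = test_loss < self.best_loss
        if improved:
            self.best_loss = test_loss
            self._write("best", state)
        return improved

    def _write(self, name, state):
        # pickle is a stand-in here, chosen so the sketch has no dependencies.
        with open(self.run_dir / name, "wb") as f:
            pickle.dump(state, f)
```

with this scheme, checkpoint is what you would resume an interrupted run from, while best is what you would evaluate or demo, since later training steps may have a worse test loss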
python3 keypoints.py --run_id 3 --config configs/keypoints_celeba.yaml
python3 transporter.py --run_id 1 --config configs/transporter_celeba.yaml
python3 keypoints.py --run_id 1 --config configs/keypoints.yaml
run on a specific cuda device
--device cuda:1
run in 32 bit precision
--opt_level O0
display the run live, update display every 100 minibatches
--display --display_freq 100
load from checkpoint files in directory
--load data/models/VGG_PONG_LAYERNECK/run_1/checkpoint
run a saved model in demo mode (don't train, and display live results)
--load data/models/VGG_PONG_LAYERNECK/run_1/best --demo --display --display_freq 5