PyTorch implementation of ICLR 2018 paper Learn To Pay Attention
- Oct. 29, 2018: Add the implementation of the "grid attention" module (NOT tested on CIFAR-100 or any other dataset; feel free to run your own experiments). A minimal sketch is given after this list.
  - Reference paper: https://arxiv.org/abs/1804.05338
  - Reference code: https://github.com/ozan-oktay/Attention-Gated-Networks
- Nov. 2, 2018: Release the pre-trained models (see the links below).
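For reference, here is a minimal PyTorch sketch of an additive attention gate ("grid attention") in the spirit of the reference code above. The class and parameter names are illustrative assumptions, not this repo's actual API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GridAttentionBlock(nn.Module):
    """Sketch of an additive attention gate (names are illustrative)."""
    def __init__(self, in_channels, gating_channels, inter_channels):
        super(GridAttentionBlock, self).__init__()
        # W_x: compress local features and halve their spatial size
        self.theta = nn.Conv2d(in_channels, inter_channels, kernel_size=2, stride=2, bias=False)
        # W_g: project the gating signal into the same embedding space
        self.phi = nn.Conv2d(gating_channels, inter_channels, kernel_size=1, bias=True)
        # psi: map the joint embedding to one score per spatial location
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1, bias=True)

    def forward(self, x, g):
        # x: local features (N, C, H, W); g: gating signal from a coarser layer
        theta_x = self.theta(x)
        phi_g = F.interpolate(self.phi(g), size=theta_x.shape[2:],
                              mode='bilinear', align_corners=False)
        f = F.relu(theta_x + phi_g)           # additive compatibility
        a = torch.sigmoid(self.psi(f))        # attention coefficients in (0, 1)
        a = F.interpolate(a, size=x.shape[2:], mode='bilinear', align_corners=False)
        return a * x                          # gated local features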
My implementation is based on "(VGG-att3)-concat-pc" in the paper, and I trained the model on the CIFAR-100 dataset.
I implemented two versions of the model; the only difference is whether the attention module is inserted before or after the corresponding max-pooling layer. A sketch of the attention block itself is given below.
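For intuition, here is a minimal sketch of the "pc" (parametrised compatibility) attention block, where the compatibility score is c_i = <u, l_i + g>. The names are illustrative assumptions, and the global feature g is assumed to be already projected and upsampled to the shape of the local feature map l:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """Sketch of parametrised-compatibility attention (names are illustrative)."""
    def __init__(self, in_channels, normalize_attn=True):
        super(AttentionBlock, self).__init__()
        self.normalize_attn = normalize_attn
        # the learnable vector u in c_i = <u, l_i + g>, realised as a 1x1 conv
        self.phi = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)

    def forward(self, l, g):
        # l: local features (N, C, H, W)
        # g: global feature, already projected/upsampled to the shape of l
        N, C, H, W = l.size()
        c = self.phi(l + g)                               # compatibility scores (N, 1, H, W)
        if self.normalize_attn:
            # softmax over all H*W spatial locations
            a = F.softmax(c.view(N, 1, -1), dim=2).view(N, 1, H, W)
        else:
            a = torch.sigmoid(c)
        g_out = (a * l).view(N, C, -1).sum(dim=2)         # attention-weighted descriptor (N, C)
        if not self.normalize_attn:
            g_out = g_out / (H * W)                       # average instead of weighted sum
        return a, g_out
```

With --normalize_attn the scores are softmax-normalized over spatial locations, so the output is a convex combination of local features; without it, a sigmoid gate is applied and the gated features are averaged.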
Pre-trained models:
- Google drive link
- Alternative link (Baidu Cloud Disk)
- PyTorch (>=0.4.1)
- OpenCV
- tensorboardX
NOTE: If you are using PyTorch < 0.4.1, replace torch.nn.functional.interpolate with torch.nn.Upsample (modify the code in utilities.py), as in the snippet below.
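For example (the tensor shape and scale factor here are illustrative), the two calls produce outputs of the same size:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 4, 4)  # an illustrative feature map

# PyTorch >= 0.4.1: functional form
up = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=False)

# PyTorch < 0.4.1: nn.Upsample is the module-style equivalent
up_old = torch.nn.Upsample(scale_factor=8, mode='bilinear')(x)

assert up.shape == up_old.shape == (1, 512, 32, 32)
```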
- Pay attention before max-pooling layers

  ```
  python train.py --attn_mode before --outf logs_before --normalize_attn --log_images
  ```

- Pay attention after max-pooling layers

  ```
  python train.py --attn_mode after --outf logs_after --normalize_attn --log_images
  ```
Training curves (plot images omitted): the x-axis is # iter for the loss curve and # epoch for the test-accuracy curve.
Method | VGG (Simonyan & Zisserman, 2014) | (VGG-att3)-concat-pc (ICLR 2018) | attn-before-pooling (my code) | attn-after-pooling (my code)
---|---|---|---|---
Top-1 error (%) | 30.62 | 22.97 | 22.62 | 22.92
Attention map visualization (images omitted). From left to right: L1, L2, L3, original images.