
PyTorch implementation of the ICLR 2018 paper Learn To Pay Attention (with some modifications)

Primary language: Python. License: GNU General Public License v3.0 (GPL-3.0).

LearnToPayAttention




Most Recent Updates


My implementation is based on "(VGG-att3)-concat-pc" in the paper, and I trained the model on the CIFAR-100 dataset.
I implemented two versions of the model; the only difference is whether the attention module is inserted before or after the corresponding max-pooling layer.
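
To make the before/after distinction concrete, here is a minimal sketch of a parametrised-compatibility attention block in the spirit of the paper (score c_i = <u, l_i + g>, softmax over spatial positions, attention-weighted sum). It is not the code in this repository: the class name `AttentionSketch`, the tensor shapes, and the assumption that local and global features already share the same channel count are all illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionSketch(nn.Module):
    """Illustrative attention block: c_i = <u, l_i + g> (hypothetical name)."""

    def __init__(self, channels):
        super().__init__()
        # A 1x1 convolution with one output channel plays the role of the vector u.
        self.u = nn.Conv2d(channels, 1, kernel_size=1, bias=False)

    def forward(self, local_feat, global_feat):
        # local_feat: (N, C, H, W); global_feat: (N, C, 1, 1), broadcast over H x W.
        n, c, h, w = local_feat.shape
        scores = self.u(local_feat + global_feat)                # (N, 1, H, W)
        # Softmax over all spatial positions, so each map sums to 1.
        attn = F.softmax(scores.view(n, 1, -1), dim=2).view(n, 1, h, w)
        # Attention-weighted sum of local features: one descriptor per image.
        pooled = (attn * local_feat).view(n, c, -1).sum(dim=2)   # (N, C)
        return attn, pooled


torch.manual_seed(0)
block = AttentionSketch(channels=8)
feat = torch.randn(2, 8, 16, 16)   # hypothetical intermediate feature map
g = torch.randn(2, 8, 1, 1)        # hypothetical global feature

# "before": attend on the feature map as it enters the max-pooling layer.
attn_before, pooled_before = block(feat, g)
# "after": attend on the max-pooled feature map instead.
attn_after, pooled_after = block(F.max_pool2d(feat, 2), g)
```

In this sketch the only visible effect of the placement is the resolution of the attention map (16x16 versus 8x8); the pooled descriptor has the same size in both cases.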

(New!) Pre-trained models

Google drive link
Alternative link (Baidu Cloud Disk)

Dependencies

NOTE: If you are using PyTorch < 0.4.1, replace torch.nn.functional.interpolate with torch.nn.Upsample (modify the code in utilities.py).
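
As a rough illustration of that substitution, the sketch below upsamples a tensor once with the functional call and once with the module. The input shape, scale factor, and bilinear mode are assumptions chosen for the example, not values taken from utilities.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

attn = torch.rand(1, 1, 8, 8)   # hypothetical low-resolution attention map

# Functional interface (available in PyTorch >= 0.4.1).
up_functional = F.interpolate(
    attn, scale_factor=4, mode="bilinear", align_corners=False
)

# Module interface: construct the layer once, then call it on the tensor.
upsample = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False)
up_module = upsample(attn)
```

On a recent PyTorch release both calls produce the same 32x32 result, so the change is a mechanical swap of one call for the other.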

Training

  1. Pay attention before max-pooling layers
python train.py --attn_mode before --outf logs_before --normalize_attn --log_images
  2. Pay attention after max-pooling layers
python train.py --attn_mode after --outf logs_after --normalize_attn --log_images
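
For readers unfamiliar with the flags, the following is a hypothetical sketch of how options like these could be declared with argparse. The flag names come from the commands above; treating `--normalize_attn` and `--log_images` as boolean switches, and the help texts, are assumptions. The actual train.py may define them differently.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the flags used in the commands above.
    parser = argparse.ArgumentParser(description="training options (sketch)")
    parser.add_argument("--attn_mode", choices=["before", "after"],
                        help="attend before or after the max-pooling layers")
    parser.add_argument("--outf", type=str,
                        help="output directory for logs")
    parser.add_argument("--normalize_attn", action="store_true",
                        help="assumed boolean switch")
    parser.add_argument("--log_images", action="store_true",
                        help="assumed boolean switch")
    return parser


command = "--attn_mode before --outf logs_before --normalize_attn --log_images"
args = build_parser().parse_args(command.split())
```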

Results

Training curve - loss

The x-axis is the number of iterations.

  1. Pay attention before max-pooling layers

  2. Pay attention after max-pooling layers

  3. Plot in one figure

Training curve - accuracy on test data

The x-axis is the number of epochs.

  1. Pay attention before max-pooling layers

  2. Pay attention after max-pooling layers

  3. Plot in one figure

Quantitative results (on test data of CIFAR-100)

| Method | VGG (Simonyan & Zisserman, 2014) | (VGG-att3)-concat-pc (ICLR 2018) | attn-before-pooling (my code) | attn-after-pooling (my code) |
| --- | --- | --- | --- | --- |
| Top-1 error (%) | 30.62 | 22.97 | 22.62 | 22.92 |
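
For reference, top-1 error is the percentage of test samples whose highest-scoring predicted class differs from the true label. A minimal sketch of the computation follows; the function name `top1_error` and the toy labels are made up for illustration and are not results from this repository.

```python
def top1_error(predicted, targets):
    """Return the top-1 error in percent for two equal-length label sequences."""
    if len(predicted) != len(targets) or len(targets) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    # Count the samples whose predicted class differs from the true class.
    wrong = sum(p != t for p, t in zip(predicted, targets))
    return 100.0 * wrong / len(targets)


preds = [3, 7, 7, 1, 0]    # toy predicted classes
labels = [3, 7, 2, 1, 4]   # toy ground-truth classes
error = top1_error(preds, labels)   # 2 of 5 wrong -> 40.0
```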

Attention map visualization (on test data of CIFAR-100)

From left to right: L1, L2, L3, original images

  1. Pay attention before max-pooling layers

  2. Pay attention after max-pooling layers
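
One common way to produce overlays like these is to upsample each attention map to the image resolution, rescale it to [0, 1], and blend it with the image. The sketch below illustrates that idea only; the function name `overlay_attention`, the blending weight `alpha`, and the grayscale heat map are assumptions, and the repository's own visualization code may differ.

```python
import torch
import torch.nn.functional as F


def overlay_attention(image, attn, alpha=0.5):
    """Blend an attention map onto an image (illustrative helper).

    image: (3, H, W) tensor with values in [0, 1]
    attn:  (h, w) tensor of attention weights
    """
    height, width = image.shape[1:]
    # Upsample the attention map to the image resolution.
    up = F.interpolate(attn[None, None], size=(height, width),
                       mode="bilinear", align_corners=False)[0]   # (1, H, W)
    # Min-max rescale so the strongest response maps to about 1.
    up = (up - up.min()) / (up.max() - up.min() + 1e-8)
    # The single-channel map broadcasts over the three colour channels.
    return (1 - alpha) * image + alpha * up


torch.manual_seed(0)
image = torch.rand(3, 32, 32)                                # stand-in for a CIFAR-sized image
attn = torch.softmax(torch.randn(8 * 8), dim=0).view(8, 8)   # stand-in attention map
blended = overlay_attention(image, attn)
```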