Implementing Stand-Alone Self-Attention in Vision Models using PyTorch (13 Jun 2019)

  • Stand-Alone Self-Attention in Vision Models paper
  • Authors:
    • Prajit Ramachandran (Google Research, Brain Team)
    • Niki Parmar (Google Research, Brain Team)
    • Ashish Vaswani (Google Research, Brain Team)
    • Irwan Bello (Google Research, Brain Team)
    • Anselm Levskaya (Google Research, Brain Team)
    • Jonathon Shlens (Google Research, Brain Team)
  • Awesome :)

Method

  • Attention Layer

    • Equation 1:

      y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \operatorname{softmax}_{ab}\!\left(q_{ij}^{\top} k_{ab}\right) v_{ab}

  • Relative Position Embedding

    • The row offset a − i and column offset b − j are associated with embeddings r_{a−i} and r_{b−j} respectively, each with dimension ½·d_out. The row and column offset embeddings are concatenated to form r_{a−i, b−j}. This spatial-relative attention is defined by the equation below (a minimal PyTorch sketch of the full layer follows this list).

    • Equation 2:

      y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \operatorname{softmax}_{ab}\!\left(q_{ij}^{\top} k_{ab} + q_{ij}^{\top} r_{a-i,\,b-j}\right) v_{ab}

    • I referred to the following paper when implementing this part.
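
The sketch below is a minimal, single-head PyTorch take on the layer described by Equations 1 and 2: queries, keys and values come from 1 × 1 convolutions, attention is computed over a k × k neighborhood around each pixel, and learned row/column offset embeddings of dimension d_out/2 each are concatenated into the relative term. The class and variable names are mine, the border is zero-padded, and the repository's actual (multi-head) implementation may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalSelfAttention2d(nn.Module):
    """Single-head local self-attention (Eq. 1) with relative row/column
    position embeddings (Eq. 2). Illustrative sketch, not the repo's code."""
    def __init__(self, in_channels, out_channels, kernel_size=7):
        super().__init__()
        self.k = kernel_size
        self.pad = kernel_size // 2
        self.out_channels = out_channels
        # 1x1 convolutions produce per-pixel queries, keys and values
        self.query = nn.Conv2d(in_channels, out_channels, 1)
        self.key = nn.Conv2d(in_channels, out_channels, 1)
        self.value = nn.Conv2d(in_channels, out_channels, 1)
        # row / column offset embeddings, each of dimension out_channels // 2
        self.rel_rows = nn.Parameter(torch.randn(out_channels // 2, kernel_size))
        self.rel_cols = nn.Parameter(torch.randn(out_channels // 2, kernel_size))

    def forward(self, x):
        b, _, h, w = x.shape
        q = self.query(x)                                   # (b, c, h, w)
        k = self.key(F.pad(x, [self.pad] * 4))              # zero-pad the borders
        v = self.value(F.pad(x, [self.pad] * 4))
        # extract the k x k neighborhood around every pixel: (b, c, h, w, k*k)
        k = k.unfold(2, self.k, 1).unfold(3, self.k, 1).reshape(b, self.out_channels, h, w, -1)
        v = v.unfold(2, self.k, 1).unfold(3, self.k, 1).reshape(b, self.out_channels, h, w, -1)
        # concatenated relative embedding r_{a-i, b-j}: (c, k*k)
        rel = torch.cat([
            self.rel_rows.unsqueeze(2).expand(-1, self.k, self.k),   # row offsets
            self.rel_cols.unsqueeze(1).expand(-1, self.k, self.k),   # column offsets
        ], dim=0).reshape(self.out_channels, -1)
        # logits q^T k + q^T r, softmax over the neighborhood, weighted sum of values
        logits = (q.unsqueeze(-1) * k).sum(1) + torch.einsum('bchw,cn->bhwn', q, rel)
        attn = F.softmax(logits, dim=-1)                    # (b, h, w, k*k)
        return torch.einsum('bchwn,bhwn->bchw', v, attn)
```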

  1. Replacing Spatial Convolutions
    - This work applies the transform to the ResNet family of architectures: the proposed transform swaps the 3 × 3 spatial convolution for a self-attention layer as defined in Equation 2 above.
    - A 2 × 2 average pooling with stride 2 operation follows the attention layer whenever spatial downsampling is required (see the sketch after this list).
  2. Replacing the Convolutional Stem
    - The initial layers of a CNN, sometimes referred to as the stem, play a critical role in learning local features such as edges, which later layers use to identify global objects.
    - The stem performs self-attention within each 4 × 4 spatial block of the original image, followed by batch normalization and a 4 × 4 max pool operation.
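
As a rough illustration of item 1 in this list, the sketch below shows one way a ResNet-style bottleneck block could swap its 3 × 3 convolution for the attention layer, with the 2 × 2 stride-2 average pool applied right after attention when downsampling is needed. It reuses LocalSelfAttention2d from the sketch above; the block layout and names are assumptions, not the repository's exact code (the attention stem of item 2 is not shown).

```python
import torch.nn as nn

class AttentionBottleneck(nn.Module):
    """ResNet-style bottleneck with the 3x3 convolution replaced by local
    self-attention; a 2x2 stride-2 average pool follows the attention layer
    whenever spatial downsampling is required. Illustrative sketch only."""
    def __init__(self, in_channels, mid_channels, out_channels, stride=1, kernel_size=7):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, mid_channels, 1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            LocalSelfAttention2d(mid_channels, mid_channels, kernel_size),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        ]
        if stride == 2:
            layers.append(nn.AvgPool2d(2, 2))   # downsample after attention
        layers += [
            nn.Conv2d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.body = nn.Sequential(*layers)
        # projection shortcut when the spatial size or channel count changes
        shortcut = []
        if stride == 2:
            shortcut.append(nn.AvgPool2d(2, 2))
        if stride == 2 or in_channels != out_channels:
            shortcut += [nn.Conv2d(in_channels, out_channels, 1, bias=False),
                         nn.BatchNorm2d(out_channels)]
        self.shortcut = nn.Sequential(*shortcut)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.body(x) + self.shortcut(x))
```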

Experiments

Setup

  • Spatial extent: 7
  • Attention heads: 8
  • Layers (per-stage block counts; a configuration sketch follows this list):
    • ResNet 26: [1, 2, 4, 1]
    • ResNet 38: [2, 3, 5, 2]
    • ResNet 50: [3, 4, 6, 3]
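
For illustration, the per-stage block counts above could be turned into network stages roughly as follows. AttentionBottleneck comes from the sketch in the Method section, and the stage widths are conventional ResNet values that I am assuming, not values confirmed by the repository.

```python
import torch.nn as nn

def make_stages(num_blocks=(1, 2, 4, 1), widths=(64, 128, 256, 512), in_channels=64):
    """Assemble attention bottleneck stages from per-stage block counts.
    Hypothetical sketch; widths and naming are assumptions."""
    stages = []
    for n, width in zip(num_blocks, widths):
        blocks = []
        for i in range(n):
            # downsample at the first block of every stage after the first
            stride = 2 if (i == 0 and len(stages) > 0) else 1
            blocks.append(AttentionBottleneck(in_channels, width, width * 4, stride=stride))
            in_channels = width * 4
        stages.append(nn.Sequential(*blocks))
    return nn.Sequential(*stages)
```
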
| Dataset  | Model                        | Accuracy | Parameters (My Model, Paper Model) |
|----------|------------------------------|----------|------------------------------------|
| CIFAR-10 | ResNet 26                    | 90.94%   | 8.30M, -                           |
| CIFAR-10 | Naive ResNet 26              | 94.29%   | 8.74M                              |
| CIFAR-10 | ResNet 26 + stem             | 90.22%   | 8.30M, -                           |
| CIFAR-10 | ResNet 38 (WORK IN PROGRESS) | 89.46%   | 12.1M, -                           |
| CIFAR-10 | Naive ResNet 38              | 94.93%   | 15.0M                              |
| CIFAR-10 | ResNet 50 (WORK IN PROGRESS) | -        | 16.0M, -                           |
| IMAGENET | ResNet 26 (WORK IN PROGRESS) | -        | 10.3M, 10.3M                       |
| IMAGENET | ResNet 38 (WORK IN PROGRESS) | -        | 14.1M, 14.1M                       |
| IMAGENET | ResNet 50 (WORK IN PROGRESS) | -        | 18.0M, 18.0M                       |

Usage

Requirements

  • torch==1.0.1
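
As a quick, hedged sanity check (assuming the LocalSelfAttention2d sketch from the Method section is on the import path; the experiments above use 8 heads, while the sketch is single-head):

```python
import torch

# Spatial extent 7 on a CIFAR-10-sized feature map; shapes are illustrative only.
layer = LocalSelfAttention2d(in_channels=64, out_channels=64, kernel_size=7)
x = torch.randn(2, 64, 32, 32)   # two 64-channel 32x32 feature maps
out = layer(x)
print(out.shape)                 # expected: torch.Size([2, 64, 32, 32])
```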

Todo

  • Experiments
  • IMAGENET
  • Review relative position embedding, attention stem
  • Code Refactoring

Reference