This repository contains the official implementation of Locally Enhanced Self-Attention: Combining Self-Attention and Convolution as Local and Context Terms. The code for image classification and object detection is based on axial-deeplab and mmdetection.
Visualizing Locally Enhanced Self-Attention (LESA) at one spatial location.
Self-attention has become prevalent in computer vision models. Inspired by fully connected Conditional Random Fields (CRFs), we decompose self-attention into local and context terms. These correspond to the unary and binary terms in a CRF and are implemented by attention mechanisms with projection matrices. We observe that the unary terms make only small contributions to the outputs, while standard CNNs that rely solely on the unary terms achieve strong performance on a variety of tasks. Therefore, we propose Locally Enhanced Self-Attention (LESA), which enhances the unary term by combining it with convolutions and uses a fusion module to dynamically couple the unary and binary operations. In our experiments, we replace the self-attention modules with LESA. The results on ImageNet and COCO show that LESA outperforms the convolution and self-attention baselines on image recognition, object detection, and instance segmentation.
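The decomposition described above can be illustrated with a small NumPy sketch. This is not the code in this repository: it uses a 1D sequence instead of a 2D feature map, a single head, scaled dot-product attention, and no position embedding. The local (unary) term is taken here as the contribution of each position to itself and the context (binary) term as the contribution of all other positions. The convolution that LESA adds to the local term is omitted, and `fuse` with its `w_gate` parameter is a hypothetical stand-in for the fusion module, assumed here to be a sigmoid gate on the context term.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def decompose_attention(x, w_q, w_k, w_v):
    """Split single-head self-attention into local (j == i) and context (j != i) terms."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (n, n) attention weights
    self_weight = np.diag(attn)                     # weight of each position on itself
    local = self_weight[:, None] * v                # unary term
    context = (attn - np.diag(self_weight)) @ v     # binary term (all other positions)
    return local, context, attn


def fuse(local, context, w_gate):
    """Hypothetical fusion: a sigmoid gate computed from both terms scales the context term."""
    features = np.concatenate([local, context], axis=-1)
    gate = 1.0 / (1.0 + np.exp(-(features @ w_gate)))
    return local + gate * context


rng = np.random.default_rng(0)
n, d = 6, 4  # toy sizes: 6 positions, 4 channels
x = rng.normal(size=(n, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

local, context, attn = decompose_attention(x, w_q, w_k, w_v)
full = attn @ (x @ w_v)  # ordinary self-attention output

# The two terms add up exactly to standard self-attention.
assert np.allclose(local + context, full)

w_gate = rng.normal(size=(2 * d, d))  # random weights, for illustration only
fused = fuse(local, context, w_gate)
print("mean self-weight:", np.diag(attn).mean())
print("fused output shape:", fused.shape)
```

The assertion checks that the split is exact, so replacing or reweighting one of the two terms is the only place where the output can differ from plain self-attention. Refer to LESA_classification for the actual module.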
Effectiveness of Locally Enhanced Self-Attention (LESA) on COCO object detection and instance segmentation. Panels, left to right: Image, Convolution, Self-Attention, LESA.
If you find LESA helpful in your project, please consider citing our paper.
@article{yang2021locally,
title={Locally Enhanced Self-Attention: Rethinking Self-Attention as Local and Context Terms},
author={Yang, Chenglin and Qiao, Siyuan and Kortylewski, Adam and Yuille, Alan},
journal={arXiv preprint arXiv:2107.05637},
year={2021}
}
Please refer to LESA_classification for details.
Method | Model | Top-1 Acc. | Top-5 Acc. |
---|---|---|---|
LESA_ResNet50 | Download | 79.55 | 94.79 |
LESA_WRN50 | Download | 80.18 | 95.07 |
Please refer to LESA_detection for details.
Method | Backbone | Pretrained | Model | Box AP | Mask AP |
---|---|---|---|---|---|
Mask-RCNN | LESA_ResNet50 | Download | Download | 44.2 | 39.6 |
HTC | LESA_WRN50 | Download | Download | 50.5 | 44.4 |
This project is based on axial-deeplab and mmdetection.
The relative position embedding is based on bottleneck-transformer-pytorch.
ResNet is based on pytorch/vision. Classification helper functions are based on pytorch-classification.