# BAM: A lightweight but efficient Balanced Attention Mechanism for super-resolution

We propose a lightweight but efficient balanced attention mechanism (BAM) for the SISR task.


Note: the project is a bit messy; I will clean up the code later.

This project is built on IDN. Thanks to all the other researchers who made their code accessible.

## Requirements

  • PyTorch>=1.0.0
  • Numpy 1.15.4
  • Pillow 5.4.1
  • h5py 2.8.0
  • tqdm 4.30.0

## Results

(Results figure)

## Train

The DIV2K and Set5 datasets converted to HDF5 can be downloaded from the links below. Otherwise, you can use prepare.py to create a custom dataset (a small HDF5 inspection sketch follows the table).

| Dataset | Scale | Type  | Link     |
|---------|-------|-------|----------|
| DIV2K   | 2     | Train | Download |
| DIV2K   | 3     | Train | Download |
| DIV2K   | 4     | Train | Download |
| Set5    | 2     | Eval  | Download |
| Set5    | 3     | Eval  | Download |
| Set5    | 4     | Eval  | Download |
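
If you build your own HDF5 files with prepare.py, a quick way to verify what was written is to list the datasets and their shapes. This is a minimal sketch, not part of the repo; the exact keys and layout depend on how prepare.py wrote the file, and the path is only an example.

```python
import h5py

# Walk the HDF5 file and print every group/dataset with its shape so the
# layout can be checked before training. The file path is just an example.
with h5py.File('./h5file_Set5_x4_test', 'r') as f:
    f.visititems(lambda name, obj: print(name, getattr(obj, 'shape', '')))
```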

The Flickr2K dataset can be downloaded from the link below, and then you can use prepare.py to create a custom dataset: https://link.csdn.net/?target=http%3A%2F%2Fcv.snu.ac.kr%2Fresearch%2FEDSR%2FFlickr2K.tar

To prepare data: `python3 prepare.py --images-dir ../../DIV2K/DIV2K_train_HR --output-path ./h5file_DIV2K_train_HR_x4_train --scale 4 --eval False`

To train: `python3 train.py --choose_net DRLN_BlancedAttention --train_file ./h5file_mirflickr_train_HR_x3_train --eval_file ./h5file_Set5_x4_test`

To evaluate all SR scales and all networks (download the checkpoints first): `python3 eval_allsize_allnet.py`

To evaluate a single image: `python3 eval_singleimg.py --lr_image_file ./savedimg/Set5/4/EDSR_blanced_attention_2.png --hr_image_file ../classical_SR_datasets/Set5/Set5/butterfly.png`
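
For reference, the core of single-image evaluation is a PSNR computation between the SR output and the HR ground truth. The sketch below is a standalone illustration using Pillow and NumPy from the requirements, assuming both images have the same size; eval_singleimg.py may additionally crop borders or evaluate on the Y channel, so its numbers can differ.

```python
import numpy as np
from PIL import Image

def psnr(sr_path, hr_path, max_val=255.0):
    # Load both images as RGB float arrays; they must have identical shapes.
    sr = np.array(Image.open(sr_path).convert('RGB'), dtype=np.float64)
    hr = np.array(Image.open(hr_path).convert('RGB'), dtype=np.float64)
    mse = np.mean((sr - hr) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)

print(psnr('./savedimg/Set5/4/EDSR_blanced_attention_2.png',
           '../classical_SR_datasets/Set5/Set5/butterfly.png'))
```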

To infer SR images for all scales and all networks (the SR images will be saved in the directory ./savedimg/*): `python3 infer_allsize_allnet.py`

## Checkpoints

We provide checkpoints for all networks and all scales so that our experiments can be verified. You can get them from https://pan.baidu.com/s/1gy-3jcikT2h-QfRduwoibg (password: 2ubm).

If the password fails or you have any other questions, please contact me: 2463908977@qq.com

## Attention mechanism

Our attention mechanism is very tiny and efficient, and it has also proved effective in semantic segmentation tasks, especially for lightweight models.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Squeeze spatial information with global average pooling, then
        # re-weight channels through a bottleneck of two 1x1 convolutions.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.PReLU(in_planes // ratio)
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        return self.sigmoid(avg_out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # Collapse channels with a max over the channel dimension, then
        # produce a spatial attention map with a single convolution.
        self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = self.conv1(max_out)
        return self.sigmoid(x)


class BlancedAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(BlancedAttention, self).__init__()
        self.ca = ChannelAttention(in_planes, reduction)
        self.sa = SpatialAttention()

    def forward(self, x):
        # Balance channel and spatial attention by multiplying the two maps
        # before applying them to the input features.
        ca_ch = self.ca(x)
        sa_ch = self.sa(x)
        return ca_ch.mul(sa_ch) * x
```
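
A quick way to see how the module behaves is to wrap it in a residual block and run a dummy tensor through it. The block below is only an illustrative sketch; it is not the exact integration used by the DRLN/EDSR variants in this repo.

```python
import torch
import torch.nn as nn

class ResidualBlockWithBAM(nn.Module):
    """Illustrative residual block (an assumption, not the repo's backbone)."""
    def __init__(self, channels=64):
        super(ResidualBlockWithBAM, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.PReLU(channels),
            nn.Conv2d(channels, channels, 3, padding=1),
        )
        self.bam = BlancedAttention(channels)

    def forward(self, x):
        # Re-weight the block's features with BAM before the skip connection.
        return x + self.bam(self.body(x))

x = torch.randn(1, 64, 32, 32)
print(ResidualBlockWithBAM()(x).shape)  # torch.Size([1, 64, 32, 32])
```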