RAMiT

Reciprocal Attention Mixing Transformer for Lightweight Image Restoration (CVPR 2024 Workshop NTIRE)

Haram Choi*, Cheolwoong Na, Jihyeon Oh, Seungjae Lee, Jinseop Kim, Subeen Choe, Jeongmin Lee, Taehoon Kim, and Jihoon Yang+

*: This work was done during the first author's Master's course at Sogang University.

+: Corresponding author.

arXiv paper supplement visual poster

  • Proposes RAMiT, which employs Dimensional Reciprocal Attention Mixing Transformer (D-RAMiT) blocks and a Hierarchical Reciprocal Attention Mixer (H-RAMi)
  • D-RAMiT: computes bi-dimensional (spatial and channel) self-attentions in parallel to capture both local and global dependencies
  • H-RAMi: uses multi-scale attention to decide where, and how much, to attend, both semantically and globally
  • Achieves state-of-the-art results on five lightweight image restoration tasks: Super-Resolution, Color Denoising, Grayscale Denoising, Low-Light Enhancement, and Deraining
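To make the "bi-dimensional self-attention in parallel" idea concrete, here is a minimal NumPy sketch. It is not the authors' implementation: Q/K/V projections and multi-head splitting are omitted, the channel split ratio is a hypothetical parameter, and the paper's mixing module is replaced by a plain linear projection (`w_mix`, also hypothetical). Spatial attention builds an (N, N) map over pixels, while channel attention builds a (C, C) map so each channel aggregates all positions at once.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def spatial_attention(x):
    # x: (N, C) tokens of one local window. The (N, N) weight matrix lets
    # every pixel attend to every other pixel in the window.
    _, c = x.shape
    q = k = v = x  # Q/K/V projections omitted (identity) to keep the sketch short
    attn = softmax(q @ k.T / np.sqrt(c), axis=-1)
    return attn @ v


def channel_attention(x):
    # Transposed attention: the (C, C) weight matrix lets every channel
    # aggregate information from all N spatial positions.
    n, _ = x.shape
    q = k = v = x.T  # (C, N)
    attn = softmax(q @ k.T / np.sqrt(n), axis=-1)
    return (attn @ v).T  # back to (N, C)


def bi_dimensional_attention(x, w_mix, spatial_ratio=0.5):
    # Split channels into two groups, run spatial and channel attention on
    # them in parallel, then concatenate and mix with a linear projection.
    _, c = x.shape
    cs = int(c * spatial_ratio)
    out_s = spatial_attention(x[:, :cs])
    out_c = channel_attention(x[:, cs:])
    mixed = np.concatenate([out_s, out_c], axis=-1)  # (N, C)
    return mixed @ w_mix


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 32))                    # one 8x8 window, 32 channels
w_mix = rng.standard_normal((32, 32)) / np.sqrt(32)  # hypothetical mixing weights
y = bi_dimensional_attention(x, w_mix)
print(y.shape)  # (64, 32)
```

The two branches see the same features but aggregate along different axes, which is why running them side by side can cover both local (pixel-to-pixel) and global (channel-wise, all positions) context.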

News

- June 07, 2024: Presentation poster available.

- April 17, 2024: Accepted at CVPR 2024 Workshop NTIRE (New Trends in Image Restoration and Enhancement).

- July 12, 2023: Code released publicly.

- May 19, 2023: Preprint posted on arXiv.

Model Architecture

[Figure: RAMiT model architecture]

Dimensional Reciprocal Self-Attentions

[Figure: dimensional reciprocal self-attentions]

Lightweight Image Restoration Results

Super-Resolution (SR)

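[Figure: lightweight SR results]

slimSR

[Figure: slimSR results]

SR trade-off

[Figure: SR trade-off]

Color Denoising (CDN)

[Figure: color denoising results]

Grayscale Denoising (GDN)

[Figure: grayscale denoising results]

Low-Light Enhancement (LLE)

[Figure: low-light enhancement results]

Deraining (DR)

[Figure: deraining results]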

Visual Results

* Visual results for the other test images can be downloaded from my drive.

Super-Resolution (SR)

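[Figure: SR visual results]

Color Denoising (CDN)

[Figure: color denoising visual results]

Low-Light Enhancement (LLE)

[Figure: low-light enhancement visual results]

Deraining (DR)

[Figure: deraining visual results]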

Testing Instructions (with pre-trained models)

Edit the first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) to match your setup.
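The first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) describe a distributed process group. A minimal sketch of how such values typically combine, assuming (not verified against ddp_main_test.py or ddp_main.py) the usual one-process-per-GPU convention of torch.distributed: the world size is nodes times GPUs per node, and each local GPU gets a global rank offset by its node's rank. Under that convention --ip_address is the address of the rank-0 node. The commands here use the gloo backend, which runs on CPU and GPU; nccl is the common choice for multi-GPU NVIDIA training. `ddp_layout` is a hypothetical helper.

```python
def ddp_layout(total_nodes, gpus_per_node, node_rank):
    """Return (world_size, global ranks of the processes on this node),
    assuming one process per GPU as in standard torch.distributed setups."""
    if not 0 <= node_rank < total_nodes:
        raise ValueError("node_rank must be in [0, total_nodes)")
    world_size = total_nodes * gpus_per_node
    # Local GPU i on node `node_rank` gets global rank node_rank * gpus_per_node + i.
    ranks = [node_rank * gpus_per_node + i for i in range(gpus_per_node)]
    return world_size, ranks


# Single machine, one GPU (as in the testing commands):
print(ddp_layout(1, 1, 0))  # (1, [0])
# Single machine, two GPUs (as in the training commands):
print(ddp_layout(1, 2, 0))  # (2, [0, 1])
# Second of two machines with two GPUs each:
print(ddp_layout(2, 2, 1))  # (4, [2, 3])
```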

RAMiT SR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm

RAMiT-1 SR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm

RAMiT-slimSR SR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm

RAMiT CDN

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_CDN.pth --task lightweight_dn --target_mode light_dn --result_image_save --img_norm

RAMiT-1 CDN

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_CDN.pth --task lightweight_dn --target_mode light_dn --result_image_save --img_norm

RAMiT GDN

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_GDN.pth --task lightweight_dn --target_mode light_graydn --result_image_save --img_norm

RAMiT LLE

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_LLE.pth --task lightweight_lle --target_mode light_lle --result_image_save --img_norm

RAMiT-slimLLE LLE

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimLLE --pretrain_path ./pretrained/RAMiT-slimLLE_LLE.pth --task lightweight_lle --target_mode light_lle --result_image_save --img_norm

RAMiT DR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_DR.pth --task lightweight_dr --target_mode light_dr --result_image_save --img_norm
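All testing commands above share one pattern: only --model_name, --pretrain_path, --task, and --target_mode change. The sketch below reproduces that pattern from the listed commands so a single setting can be launched programmatically. `SETTINGS`, `AVAILABLE`, and `build_test_cmd` are hypothetical helpers, not part of the repository, and the IP address is a placeholder.

```python
import shlex

# (task, target_mode, checkpoint suffix) per setting, taken from the commands above.
SETTINGS = {
    "x2": ("lightweight_sr", "light_x2", "X2"),
    "x3": ("lightweight_sr", "light_x3", "X3"),
    "x4": ("lightweight_sr", "light_x4", "X4"),
    "CDN": ("lightweight_dn", "light_dn", "CDN"),
    "GDN": ("lightweight_dn", "light_graydn", "GDN"),
    "LLE": ("lightweight_lle", "light_lle", "LLE"),
    "DR": ("lightweight_dr", "light_dr", "DR"),
}

# Model/setting pairs that have a pre-trained checkpoint in the commands above.
AVAILABLE = (
    {("RAMiT", s) for s in SETTINGS}
    | {("RAMiT-1", s) for s in ("x2", "x3", "x4", "CDN")}
    | {("RAMiT-slimSR", s) for s in ("x2", "x3", "x4")}
    | {("RAMiT-slimLLE", "LLE")}
)


def build_test_cmd(model, setting, ip_address, backend="gloo"):
    """Return the argument list for one single-GPU test run."""
    if (model, setting) not in AVAILABLE:
        raise ValueError(f"no pre-trained checkpoint for {model} / {setting}")
    task, target_mode, suffix = SETTINGS[setting]
    return [
        "python3", "ddp_main_test.py",
        "--total_nodes", "1", "--gpus_per_node", "1", "--node_rank", "0",
        "--ip_address", ip_address, "--backend", backend,
        "--model_name", model,
        "--pretrain_path", f"./pretrained/{model}_{suffix}.pth",
        "--task", task, "--target_mode", target_mode,
        "--result_image_save", "--img_norm",
    ]


cmd = build_test_cmd("RAMiT-1", "CDN", "127.0.0.1")  # placeholder local address
print(shlex.join(cmd))
# To actually launch it: subprocess.run(cmd, check=True)
```

Returning a list rather than a single string avoids shell-quoting issues when the command is passed to `subprocess.run`.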

Training Instructions

Edit the first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) to match your setup.

RAMiT SR

(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm

(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
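The --half_list values read as learning-rate milestones. A minimal sketch, assuming (the flag name suggests it, but this README does not confirm it) that the learning rate is halved once at each listed epoch. The initial learning rate is not given here, so the example uses a relative base of 1.0; the warm-up controlled by --warmup_epoch is not modeled. `parse_int_list` and `lr_at_epoch` are hypothetical helpers.

```python
def parse_int_list(text):
    # "200,300,400" -> [200, 300, 400]
    return [int(tok) for tok in text.split(",")]


def lr_at_epoch(epoch, base_lr, half_list):
    # Halve the learning rate once for every milestone already reached.
    n_halvings = sum(1 for milestone in half_list if epoch >= milestone)
    return base_lr * 0.5 ** n_halvings


x2_milestones = parse_int_list("200,300,400,425,450,475")
for epoch in (0, 199, 200, 300, 499):
    print(epoch, lr_at_epoch(epoch, 1.0, x2_milestones))
```

In PyTorch, the same step shape is what `torch.optim.lr_scheduler.MultiStepLR` produces with `milestones=x2_milestones` and `gamma=0.5` when stepped once per epoch. With six milestones, the final learning rate of the x2 run would be 1/64 of the initial one.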

RAMiT-1 SR

(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm

(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-1 --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-1 --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

RAMiT-slimSR SR

(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm

(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-slimSR --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-slimSR --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

RAMiT CDN (blind noise level)

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_dn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm
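"Blind noise level" means the model is trained across a range of noise strengths instead of one fixed sigma. A minimal sketch of what --sigma 0,50 plausibly does, assuming (not confirmed by this README) that a noise level is drawn uniformly from [0, 50] for each training patch on the 0-255 intensity scale and additive white Gaussian noise is applied. Whether the scripts clip the result or sample per patch versus per batch is not stated here. `add_blind_gaussian_noise` is a hypothetical helper.

```python
import numpy as np


def add_blind_gaussian_noise(img, sigma_range=(0.0, 50.0), rng=None):
    """Add white Gaussian noise with a randomly drawn level.

    img is a float array on the 0-255 scale. Drawing sigma per sample means
    the network is never told the noise level (blind denoising).
    """
    rng = np.random.default_rng() if rng is None else rng
    low, high = sigma_range
    sigma = rng.uniform(low, high)
    noisy = img + rng.normal(0.0, sigma, size=img.shape)  # no clipping applied
    return noisy, sigma


rng = np.random.default_rng(0)
clean = rng.uniform(0, 255, size=(64, 64, 3))  # stand-in for a color training patch
noisy, sigma = add_blind_gaussian_noise(clean, rng=rng)
print(f"sigma = {sigma:.2f}")
```

For grayscale denoising (light_graydn) the same function would apply to single-channel patches.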

RAMiT-1 CDN (blind noise level)

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --target_mode light_dn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT GDN (blind noise level)

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_graydn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT LLE

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_lle --task lightweight_lle --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name LLE --total_epochs 400 --half_list 200,300,350,375 --img_norm
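The denoising, LLE, and deraining commands pass three comma-separated values to --training_patch_size, --batch_size, and --progressive_epoch. A minimal sketch of one plausible reading, assuming (not confirmed by this README) that --progressive_epoch lists the epochs at which training switches to the matching patch size and batch size. Whether --batch_size is per GPU or global is also not stated. `progressive_stage` is a hypothetical helper.

```python
import bisect


def parse_int_list(text):
    # "64,96,128" -> [64, 96, 128]
    return [int(tok) for tok in text.split(",")]


def progressive_stage(epoch, starts, patch_sizes, batch_sizes):
    # starts: ascending stage start epochs beginning at 0, e.g. [0, 100, 200].
    # bisect_right finds the last stage whose start epoch is <= epoch.
    idx = bisect.bisect_right(starts, epoch) - 1
    return patch_sizes[idx], batch_sizes[idx]


starts = parse_int_list("0,100,200")
patches = parse_int_list("64,96,128")
batches = parse_int_list("32,16,8")
# Pixels per batch stay roughly constant across stages: 131072, 147456, 131072.
for epoch in (0, 99, 100, 250):
    patch, batch = progressive_stage(epoch, starts, patches, batches)
    print(epoch, patch, batch, patch * patch * batch)
```

Shrinking the batch as the patch grows keeps the number of pixels per step, and hence memory use, roughly constant while the model sees larger spatial context later in training. The SR commands (--progressive_epoch 0) reduce to a single stage.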

RAMiT-slimLLE LLE

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimLLE --target_mode light_lle --task lightweight_lle --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name LLE --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT DR

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_dr --task lightweight_dr --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DR --total_epochs 400 --half_list 200,300,350,375 --img_norm

Citation

(preferred)
@inproceedings{choi2024reciprocal,
  title={Reciprocal Attention Mixing Transformer for Lightweight Image Restoration},
  author={Choi, Haram and Na, Cheolwoong and Oh, Jihyeon and Lee, Seungjae and Kim, Jinseop and Choe, Subeen and Lee, Jeongmin and Kim, Taehoon and Yang, Jihoon},
  booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
  pages={5992--6002},
  year={2024}
}

@article{choi2023reciprocal,
  title={Reciprocal Attention Mixing Transformer for Lightweight Image Restoration},
  author={Choi, Haram and Na, Cheolwoong and Oh, Jihyeon and Lee, Seungjae and Kim, Jinseop and Choe, Subeen and Lee, Jeongmin and Kim, Taehoon and Yang, Jihoon},
  journal={arXiv preprint arXiv:2305.11474},
  year={2023}
}

My Related Works

  • N-Gram in Swin Transformers for Efficient Lightweight Image Super-Resolution, CVPR 2023. proceedings arXiv code
  • Exploration of Lightweight Single Image Denoising with Transformers and Truly Fair Training, ICMR 2023. proceedings arXiv code