Reciprocal Attention Mixing Transformer for Lightweight Image Restoration (CVPR 2024 Workshop NTIRE)
Haram Choi*, Cheolwoong Na, Jihyeon Oh, Seungjae Lee, Jinseop Kim, Subeen Choe, Jeongmin Lee, Taehoon Kim, and Jihoon Yang+
*: This work was done during the Master's course at Sogang University.
+: Corresponding author.
- Proposes RAMiT, which employs the Dimensional Reciprocal Attention Mixing Transformer (D-RAMiT) and the Hierarchical Reciprocal Attention Mixer (H-RAMi)
- D-RAMiT: computes bi-dimensional self-attention in parallel to capture both local and global dependencies
- H-RAMi: uses multi-scale attention to consider where and how much attention to pay, both semantically and globally
- Achieves state-of-the-art results on five lightweight image restoration tasks: Super-Resolution, Color Denoising, Grayscale Denoising, Low-Light Enhancement, and Deraining
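The D-RAMiT bullet can be made concrete with a toy, pure-Python illustration of running two self-attentions side by side. This is not the RAMiT implementation: it assumes the two dimensions are the spatial (token) axis and the channel axis, splits the channels half and half, and omits learned Q/K/V projections, windowing (which is what would make the spatial branch local), multiple heads, and normalization. All function names are hypothetical.

```python
import math


def softmax(row):
    # Numerically stable softmax over one list of scores.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def matmul(a, b):
    # Plain list-of-lists matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def self_attention(x):
    """Scaled dot-product self-attention over the rows of x (no learned projections)."""
    d = len(x[0])
    scores = matmul(x, transpose(x))
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, x)


def spatial_attention(x):
    # Rows are N tokens (pixels), columns are channels: tokens attend to tokens.
    return self_attention(x)


def channel_attention(x):
    # Transpose so rows are channels, let channels attend to channels, transpose back.
    return transpose(self_attention(transpose(x)))


def bi_dimensional_attention(x):
    """Toy sketch: split channels in half, run both attentions in parallel, concatenate."""
    half = len(x[0]) // 2
    left = [row[:half] for row in x]
    right = [row[half:] for row in x]
    out_spatial = spatial_attention(left)
    out_channel = channel_attention(right)
    return [s + c for s, c in zip(out_spatial, out_channel)]


# Three tokens with four channels each.
x = [[1.0, 0.0, 0.5, 2.0], [0.0, 1.0, 1.5, -1.0], [0.5, 0.5, 0.0, 0.0]]
for row in bi_dimensional_attention(x):
    print([round(v, 3) for v in row])
```

Because each softmax row sums to 1, the spatial branch outputs convex combinations of token vectors, while the channel branch mixes channels within each token; concatenating the two keeps the N x C shape.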
- June 07, 2024: Presentation poster available.
- April 17, 2024: Accepted at CVPR 2024 Workshop NTIRE (New Trends in Image Restoration and Enhancement)
- July 12, 2023: Code released publicly
- May 19, 2023: Preprint released on arXiv
* Visual results on the other images can be downloaded from my drive.
Edit the first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) so that they match your machine(s).
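How ddp_main_test.py consumes these five values is not documented here. Assuming it follows the common torch.distributed convention (an assumption, not confirmed by this README), the total process count and each process's global rank are derived as below. The helper names are hypothetical.

```python
def world_size(total_nodes: int, gpus_per_node: int) -> int:
    """Total number of processes across all machines."""
    return total_nodes * gpus_per_node


def global_rank(node_rank: int, gpus_per_node: int, local_rank: int) -> int:
    """Rank of one process among all processes (common torch.distributed convention)."""
    return node_rank * gpus_per_node + local_rank


# Example: two machines with 2 GPUs each. In this assumed setup, both machines run
# the same command with --total_nodes 2 --gpus_per_node 2, each with its own
# --node_rank (0 or 1), and --ip_address pointing at the node_rank 0 machine.
for node in range(2):
    ranks = [global_rank(node, 2, local) for local in range(2)]
    print(f"node_rank {node}: global ranks {ranks} of world size {world_size(2, 2)}")
# prints:
# node_rank 0: global ranks [0, 1] of world size 4
# node_rank 1: global ranks [2, 3] of world size 4
```

The single-GPU test commands (--total_nodes 1 --gpus_per_node 1 --node_rank 0) correspond to a world size of 1.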
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_CDN.pth --task lightweight_dn --target_mode light_dn --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_CDN.pth --task lightweight_dn --target_mode light_dn --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_GDN.pth --task lightweight_dn --target_mode light_graydn --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_LLE.pth --task lightweight_lle --target_mode light_lle --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimLLE --pretrain_path ./pretrained/RAMiT-slimLLE_LLE.pth --task lightweight_lle --target_mode light_lle --result_image_save --img_norm
python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_DR.pth --task lightweight_dr --target_mode light_dr --result_image_save --img_norm
Edit the first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) so that they match your machine(s).
(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm
(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm
(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-1 --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-1 --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm
(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-slimSR --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-slimSR --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
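The super-resolution runs differ mainly in their schedules: x2 trains from scratch for 500 epochs with --half_list 200,300,400,425,450,475, while the x3/x4 warm-start runs train for 300 epochs with --warmup_epoch 10 and --half_list 50,100,150,175,200,225. Judging by the flag name, --half_list most likely lists the epochs at which the learning rate is halved; the sketch below assumes that, plus a linear warmup for --warmup_epoch. Both semantics, the base learning rate, and the function names are assumptions, not taken from ddp_main.py.

```python
def parse_half_list(text):
    """Turn a --half_list string such as "200,300,400" into a list of ints."""
    return [int(v) for v in text.split(",")]


def lr_at_epoch(epoch, base_lr, half_list, warmup_epochs=0):
    """Assumed step decay: halve the LR once for every milestone already reached.

    If warmup_epochs > 0, the LR ramps linearly up to base_lr first (also assumed).
    """
    if warmup_epochs and epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    halvings = sum(1 for milestone in half_list if epoch >= milestone)
    return base_lr * 0.5 ** halvings


scratch = parse_half_list("200,300,400,425,450,475")  # x2, 500 epochs
warm = parse_half_list("50,100,150,175,200,225")       # x3/x4, 300 epochs
for epoch in (0, 199, 200, 499):
    print("x2", epoch, lr_at_epoch(epoch, 1.0, scratch))
for epoch in (0, 9, 50, 299):
    print("x3/x4", epoch, lr_at_epoch(epoch, 1.0, warm, warmup_epochs=10))
```

Under this reading, both schedules contain six milestones, so each run ends at base_lr / 64; the warm-start runs simply compress the decay into fewer epochs.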
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_dn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --target_mode light_dn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_graydn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_lle --task lightweight_lle --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name LLE --total_epochs 400 --half_list 200,300,350,375 --img_norm
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimLLE --target_mode light_lle --task lightweight_lle --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name LLE --total_epochs 400 --half_list 200,300,350,375 --img_norm
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_dr --task lightweight_dr --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DR --total_epochs 400 --half_list 200,300,350,375 --img_norm
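The denoising, low-light enhancement, and deraining runs pass comma-separated lists to --progressive_epoch 0,100,200, --training_patch_size 64,96,128, and --batch_size 32,16,8, which reads as progressive training: larger patches with smaller batches from each listed epoch onward. The sketch assumes that stage semantics (the SR runs, with a single value each and --progressive_epoch 0, would then be one stage); progressive_stage and pixels_per_batch are hypothetical helpers.

```python
def parse_ints(text):
    """Turn a comma-separated flag value such as "64,96,128" into a list of ints."""
    return [int(v) for v in text.split(",")]


def progressive_stage(epoch, progressive_epoch, patch_sizes, batch_sizes):
    """Return (patch_size, batch_size) of the last stage whose start epoch <= epoch.

    Assumes progressive_epoch is sorted in ascending order.
    """
    stage = 0
    for i, start in enumerate(progressive_epoch):
        if epoch >= start:
            stage = i
    return patch_sizes[stage], batch_sizes[stage]


def pixels_per_batch(patch_size, batch_size):
    # Number of pixels per channel processed in one batch of square patches.
    return patch_size * patch_size * batch_size


starts = parse_ints("0,100,200")
patches = parse_ints("64,96,128")
batches = parse_ints("32,16,8")
for epoch in (0, 99, 100, 200, 399):
    p, b = progressive_stage(epoch, starts, patches, batches)
    print(epoch, p, b, pixels_per_batch(p, b))
```

The three stages process 131,072, 147,456, and 131,072 pixels per batch respectively, so the schedule trades batch size for spatial context while keeping the per-batch pixel budget roughly constant.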
Citation (preferred):
@inproceedings{choi2024reciprocal,
title={Reciprocal Attention Mixing Transformer for Lightweight Image Restoration},
author={Choi, Haram and Na, Cheolwoong and Oh, Jihyeon and Lee, Seungjae and Kim, Jinseop and Choe, Subeen and Lee, Jeongmin and Kim, Taehoon and Yang, Jihoon},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
pages={5992--6002},
year={2024}
}
@article{choi2023reciprocal,
title={Reciprocal Attention Mixing Transformer for Lightweight Image Restoration},
author={Choi, Haram and Na, Cheolwoong and Oh, Jihyeon and Lee, Seungjae and Kim, Jinseop and Choe, Subeen and Lee, Jeongmin and Kim, Taehoon and Yang, Jihoon},
journal={arXiv preprint arXiv:2305.11474},
year={2023}
}