
PyTorch code for our ICLR 2024 paper "Recursive Generalization Transformer for Image Super-Resolution"

Primary language: Python · License: Apache License 2.0 (Apache-2.0)

Recursive Generalization Transformer for Image Super-Resolution

Zheng Chen, Yulun Zhang, Jinjin Gu, Linghe Kong, and Xiaokang Yang, "Recursive Generalization Transformer for Image Super-Resolution", ICLR, 2024

[paper] [arXiv] [supplementary material] [visual results] [pretrained models]

🔥🔥🔥 News

  • 2024-02-04: Code and pre-trained models are released. 🎊🎊🎊
  • 2023-09-29: This repo is released.

Abstract: Transformer architectures have exhibited remarkable performance in image super-resolution (SR). Due to the quadratic computational complexity of self-attention (SA) in Transformers, existing methods tend to apply SA within local regions to reduce overheads. However, the local design restricts the exploitation of global context, which is crucial for accurate image reconstruction. In this work, we propose the Recursive Generalization Transformer (RGT) for image SR, which can capture global spatial information and is suitable for high-resolution images. Specifically, we propose the recursive-generalization self-attention (RG-SA). It recursively aggregates input features into representative feature maps, and then utilizes cross-attention to extract global information. Meanwhile, the channel dimensions of attention matrices ($query$, $key$, and $value$) are further scaled to mitigate the redundancy in the channel domain. Furthermore, we combine the RG-SA with local self-attention to enhance the exploitation of the global context, and propose the hybrid adaptive integration (HAI) for module integration. HAI allows direct and effective fusion between features at different levels (local or global). Extensive experiments demonstrate that our RGT outperforms recent state-of-the-art methods quantitatively and qualitatively.
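The RG-SA idea described above can be illustrated with a minimal pure-Python sketch; it is not the repository's implementation. It assumes 2×2 average pooling as a stand-in for the learned recursive aggregation, omits the learned query/key/value projections (and therefore the channel scaling), and uses hypothetical function names. What it does show is the cost structure: all N = H·W tokens act as queries, but they attend only to the N' ≤ target² representative tokens, so the score computation scales with N·N' rather than N².

```python
import math
from typing import List

Vector = List[float]             # one token: C channel values
FeatureMap = List[List[Vector]]  # H x W grid of tokens


def avg_pool_2x2(fmap: FeatureMap) -> FeatureMap:
    """Halve H and W by averaging non-overlapping 2x2 windows.

    Stand-in (assumption) for the learned, strided aggregation used by RG-SA.
    """
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [
        [
            [
                (fmap[i][j][k] + fmap[i][j + 1][k]
                 + fmap[i + 1][j][k] + fmap[i + 1][j + 1][k]) / 4.0
                for k in range(c)
            ]
            for j in range(0, w - 1, 2)
        ]
        for i in range(0, h - 1, 2)
    ]


def recursive_aggregate(fmap: FeatureMap, target: int = 2) -> FeatureMap:
    """Pool repeatedly until both sides are <= target (assumes a square, power-of-two map)."""
    while len(fmap) > target or len(fmap[0]) > target:
        fmap = avg_pool_2x2(fmap)
    return fmap


def flatten(fmap: FeatureMap) -> List[Vector]:
    """Turn an H x W grid into a list of H*W tokens (row-major)."""
    return [token for row in fmap for token in row]


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: List[Vector], keys: List[Vector],
                    values: List[Vector]) -> List[Vector]:
    """Scaled dot-product attention of every query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        out.append([sum(wt * v[ch] for wt, v in zip(weights, values))
                    for ch in range(len(values[0]))])
    return out


def rg_sa_sketch(fmap: FeatureMap, target: int = 2) -> List[Vector]:
    queries = flatten(fmap)                            # N = H*W tokens
    reps = flatten(recursive_aggregate(fmap, target))  # N' <= target**2 tokens
    # Identity projections (assumption): keys and values are the representative tokens.
    return cross_attention(queries, reps, reps)        # N x N' scores instead of N x N


if __name__ == "__main__":
    fmap = [[[float(i + j), float(i - j)] for j in range(4)] for i in range(4)]
    out = rg_sa_sketch(fmap, target=2)
    print(len(out), len(out[0]))  # prints: 16 2
```

Because the key/value set stays small regardless of image size, every query still sees a summary of the whole map, without the quadratic cost of full self-attention.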


[Figure: visual comparison of HR, LR, SwinIR, CAT, and RGT (ours)]

⚙️ Dependencies

  • Python 3.8
  • PyTorch 1.9.0
  • NVIDIA GPU + CUDA
# Clone the github repo and go to the default directory 'RGT'.
git clone https://github.com/zhengchen1999/RGT.git
cd RGT
conda create -n RGT python=3.8
conda activate RGT
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
python setup.py develop
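To confirm that an environment matches the Dependencies list above (Python 3.8, PyTorch 1.9.0, CUDA), a small check script like the following may help. It is a hypothetical helper, not part of the repository; the expected versions are copied from this README, and it only reports mismatches rather than failing.

```python
import importlib.util
import platform
import sys

# Versions listed under "Dependencies" in this README.
EXPECTED_PYTHON = (3, 8)
EXPECTED_TORCH = "1.9.0"


def check_environment() -> dict:
    """Report whether the active interpreter matches the README's dependency list."""
    report = {
        "python": platform.python_version(),
        "python_ok": sys.version_info[:2] == EXPECTED_PYTHON,
        "torch": None,
        "torch_ok": False,
        "cuda_available": False,
    }
    if importlib.util.find_spec("torch") is not None:
        import torch  # imported lazily so the script also runs without PyTorch

        report["torch"] = str(torch.__version__)
        # Wheels may carry a local suffix such as "+cu111", so compare the prefix.
        report["torch_ok"] = str(torch.__version__).startswith(EXPECTED_TORCH)
        report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```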

⚒️ TODO

  • [x] Release code and pretrained models

🔗 Contents

  1. Datasets
  2. Models
  3. Training
  4. Testing
  5. Results
  6. Citation
  7. Acknowledgements

🖨️ Datasets

Used training and testing sets can be downloaded as follows:

  • Training set: DIV2K (800 training images, 100 validation images) + Flickr2K (2650 images) [complete training dataset DF2K: Google Drive / Baidu Disk]
  • Testing set: Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset: Google Drive / Baidu Disk]
  • Visual results: Google Drive / Baidu Disk

Download the training and testing datasets and put them into the corresponding folders of datasets/. See datasets/ for details of the directory structure.

📦 Models

| Method | Params (M) | FLOPs (G) | PSNR (dB) | SSIM | Model Zoo | Visual Results |
|---|---|---|---|---|---|---|
| RGT-S | 10.20 | 193.08 | 27.89 | 0.8347 | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
| RGT | 13.37 | 251.07 | 27.98 | 0.8369 | Google Drive / Baidu Disk | Google Drive / Baidu Disk |

Performance is reported on Urban100 (×4). FLOPs are computed for an output size of 3×512×512.
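SR papers conventionally report PSNR on the luma (Y) channel after cropping `scale` pixels from each border. The sketch below computes PSNR under that convention; the ITU-R BT.601 coefficients and the cropping rule are assumptions about the usual protocol, not a statement of this repository's settings, so check the test YML files for what is actually used. Images are plain 2D lists to keep the example dependency-free.

```python
import math
from typing import List

Image = List[List[float]]  # single-channel image, values in [0, 255]


def rgb_to_y(r: float, g: float, b: float) -> float:
    """ITU-R BT.601 luma for 8-bit RGB; maps [0, 255] RGB to [16, 235]."""
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr(img1: Image, img2: Image, crop_border: int = 0, max_val: float = 255.0) -> float:
    """PSNR in dB between two same-sized single-channel images.

    `crop_border` pixels are ignored on every side (commonly set to the SR scale).
    """
    h, w = len(img1), len(img1[0])
    if h != len(img2) or w != len(img2[0]):
        raise ValueError("images must have the same shape")
    rows = range(crop_border, h - crop_border)
    cols = range(crop_border, w - crop_border)
    n = len(rows) * len(cols)
    if n == 0:
        raise ValueError("crop_border leaves no pixels to compare")
    mse = sum((img1[i][j] - img2[i][j]) ** 2 for i in rows for j in cols) / n
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / mse)


if __name__ == "__main__":
    hr = [[100.0] * 16 for _ in range(16)]
    sr = [[101.0] * 16 for _ in range(16)]  # off by one gray level everywhere
    print(round(psnr(hr, sr, crop_border=4), 2))  # prints: 48.13
```

For RGB outputs, convert each pixel with `rgb_to_y` before calling `psnr`.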

🔧 Training

  • Download the training (DF2K, already processed) and testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets and place them in datasets/.

  • Run the following scripts. The training configuration is in options/train/.

    # RGT-S, input=64x64, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_S_x2.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_S_x3.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_S_x4.yml --launcher pytorch
    
    # RGT, input=64x64, 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_x2.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_x3.yml --launcher pytorch
    python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 basicsr/train.py -opt options/train/train_RGT_x4.yml --launcher pytorch
  • The training experiments (logs and saved models) are written to experiments/.
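The six training commands differ only in the model variant and the scale. The hypothetical helper below (not part of the repository) rebuilds them from those two parameters and also prints the HR patch size implied by the 64×64 LR input, i.e. 64 × scale (128, 192, or 256 pixels). The GPU count and master port are the values used in the commands; adjust them for your machine.

```python
import shlex
from typing import List

LR_PATCH = 64       # "input=64x64" in the command comments
NUM_GPUS = 4        # --nproc_per_node used in the commands
MASTER_PORT = 4321  # --master_port used in the commands


def train_command(variant: str, scale: int,
                  num_gpus: int = NUM_GPUS, port: int = MASTER_PORT) -> List[str]:
    """Rebuild one of the distributed training commands listed in this README."""
    if variant not in ("RGT", "RGT_S"):
        raise ValueError("variant must be 'RGT' or 'RGT_S'")
    if scale not in (2, 3, 4):
        raise ValueError("scale must be 2, 3 or 4")
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={num_gpus}", f"--master_port={port}",
        "basicsr/train.py", "-opt", f"options/train/train_{variant}_x{scale}.yml",
        "--launcher", "pytorch",
    ]


def hr_patch_size(scale: int, lr_patch: int = LR_PATCH) -> int:
    """HR training patch side implied by a fixed LR input patch."""
    return lr_patch * scale


if __name__ == "__main__":
    for variant in ("RGT_S", "RGT"):
        for scale in (2, 3, 4):
            side = hr_patch_size(scale)
            print(f"# {variant} x{scale}: LR {LR_PATCH}x{LR_PATCH} -> HR {side}x{side}")
            print(shlex.join(train_command(variant, scale)))
```

To launch a run from Python instead of printing it, `subprocess.run(train_command("RGT", 4), check=True)` executes the same command; run it from the repository root.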

🔨 Testing

🌗 Test images with HR

  • Download the pre-trained models and place them in experiments/pretrained_models/.

    We provide pre-trained models for image SR: RGT-S and RGT (x2, x3, x4).

  • Download the testing (Set5, Set14, BSD100, Urban100, Manga109) datasets and place them in datasets/.

  • Run the following scripts. The testing configuration is in options/test/ (e.g., test_RGT_x2.yml).

    Note 1: You can set use_chop: True (default: False) in the YML to test the image in chopped patches, which reduces GPU memory usage.

    # No self-ensemble
    # RGT-S, reproduces results in Table 2 of the main paper
    python basicsr/test.py -opt options/test/test_RGT_S_x2.yml
    python basicsr/test.py -opt options/test/test_RGT_S_x3.yml
    python basicsr/test.py -opt options/test/test_RGT_S_x4.yml
    
    # RGT, reproduces results in Table 2 of the main paper
    python basicsr/test.py -opt options/test/test_RGT_x2.yml
    python basicsr/test.py -opt options/test/test_RGT_x3.yml
    python basicsr/test.py -opt options/test/test_RGT_x4.yml
  • The output is in results/.
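The commands above report results without self-ensemble. For reference, the commonly used ×8 geometric self-ensemble runs the model on the 8 flip/rotation variants of the input, undoes each transform on the corresponding output, and averages. The sketch below is a generic pure-Python illustration with a hypothetical `model` callable and a toy pixel-repetition "model"; it is not taken from this repository, and this README does not say whether or how the test scripts enable self-ensemble.

```python
from typing import Callable, List

Image = List[List[float]]


def rot90(img: Image) -> Image:
    """Rotate a 2D list by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*img)][::-1]


def hflip(img: Image) -> Image:
    """Mirror a 2D list left-right."""
    return [row[::-1] for row in img]


def forward_transform(img: Image, k: int, flip: bool) -> Image:
    out = hflip(img) if flip else img
    for _ in range(k):
        out = rot90(out)
    return out


def inverse_transform(img: Image, k: int, flip: bool) -> Image:
    out = img
    for _ in range((4 - k) % 4):  # undo the rotation first...
        out = rot90(out)
    return hflip(out) if flip else out  # ...then undo the flip


def self_ensemble_x8(model: Callable[[Image], Image], img: Image) -> Image:
    """Average the model's outputs over the 8 flip/rotation variants of `img`."""
    outputs = [
        inverse_transform(model(forward_transform(img, k, flip)), k, flip)
        for flip in (False, True)
        for k in range(4)
    ]
    h, w = len(outputs[0]), len(outputs[0][0])
    return [[sum(o[i][j] for o in outputs) / len(outputs) for j in range(w)]
            for i in range(h)]


def nearest_x2(img: Image) -> Image:
    """Toy stand-in for an SR model: x2 upscaling by pixel repetition."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(2)]
        out.extend([wide, list(wide)])
    return out


if __name__ == "__main__":
    lr = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    # Pixel repetition commutes with flips/rotations, so the ensemble equals a single pass.
    print(self_ensemble_x8(nearest_x2, lr) == nearest_x2(lr))  # prints: True
```

With a real network the 8 outputs differ slightly, and averaging them is what typically yields a small PSNR gain at 8× the inference cost.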

🌓 Test images without HR

  • Download the pre-trained models and place them in experiments/pretrained_models/.

    We provide pre-trained models for image SR: RGT-S and RGT (x2, x3, x4).

  • Put your dataset (single LR images) in datasets/single. Some test images are in this folder.

  • Run the following scripts. The testing configuration is in options/test/ (e.g., test_single_x2.yml).

    Note 1: The default model is RGT. You can use other models like RGT-S by modifying the YML.

    Note 2: You can set use_chop: True (default: False) in the YML to test the image in chopped patches, which reduces GPU memory usage.

    # Test on your dataset
    python basicsr/test.py -opt options/test/test_single_x2.yml
    python basicsr/test.py -opt options/test/test_single_x3.yml
    python basicsr/test.py -opt options/test/test_single_x4.yml
  • The output is in results/.
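With use_chop: True, the image is processed in pieces instead of in a single forward pass. The pure-Python sketch below shows the general idea under stated assumptions: overlapping tiles, one model call per tile, and averaging wherever upscaled tiles overlap. The tile size, overlap, averaging rule, and all function names are hypothetical choices for illustration; the repository's own chopping code may split and merge differently.

```python
from typing import Callable, List

Image = List[List[float]]


def tile_starts(length: int, tile: int, overlap: int) -> List[int]:
    """Start offsets of overlapping tiles that together cover [0, length)."""
    if tile >= length:
        return [0]
    stride = tile - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile")
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile sits flush with the border
    return starts


def chop_forward(model: Callable[[Image], Image], img: Image, scale: int,
                 tile: int = 48, overlap: int = 8) -> Image:
    """Run `model` tile by tile and average the upscaled tiles where they overlap."""
    h, w = len(img), len(img[0])
    acc = [[0.0] * (w * scale) for _ in range(h * scale)]
    cnt = [[0] * (w * scale) for _ in range(h * scale)]
    for top in tile_starts(h, tile, overlap):
        for left in tile_starts(w, tile, overlap):
            patch = [row[left:left + tile] for row in img[top:top + tile]]
            for i, row in enumerate(model(patch)):  # model output is `scale` times larger
                for j, value in enumerate(row):
                    acc[top * scale + i][left * scale + j] += value
                    cnt[top * scale + i][left * scale + j] += 1
    return [[a / c for a, c in zip(arow, crow)] for arow, crow in zip(acc, cnt)]


def repeat_upscale(scale: int) -> Callable[[Image], Image]:
    """Toy stand-in for an SR model: upscaling by pixel repetition."""
    def model(img: Image) -> Image:
        out = []
        for row in img:
            wide = [v for v in row for _ in range(scale)]
            out.extend(list(wide) for _ in range(scale))
        return out
    return model


if __name__ == "__main__":
    lr = [[float(i * 7 + j) for j in range(7)] for i in range(5)]
    model = repeat_upscale(2)
    sr = chop_forward(model, lr, scale=2, tile=3, overlap=1)
    print(sr == model(lr))  # prints: True
```

Smaller tiles lower peak memory at the cost of more forward passes. A real network's output near tile edges can differ from a full-image pass, which is why some overlap is usually kept.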

🔎 Results

RGT achieves state-of-the-art performance. Detailed results can be found in the paper.

Quantitative Comparison
  • results in Table 2 of the main paper

Visual Comparison
  • results in Figure 6 of the main paper

  • results in Figure 4 of the supplementary material

  • results in Figure 5 of the supplementary material

📎 Citation

If you find the code helpful in your research or work, please cite the following paper(s).

@inproceedings{chen2024recursive,
  title={Recursive Generalization Transformer for Image Super-Resolution},
  author={Chen, Zheng and Zhang, Yulun and Gu, Jinjin and Kong, Linghe and Yang, Xiaokang},
  booktitle={ICLR},
  year={2024}
}

💡 Acknowledgements

This code is built on BasicSR.