MST

"Mask-guided Spectral-wise Transformer for Efficient Hyperspectral Image Reconstruction" (CVPR 2022) and a Baseline for Spectral Compressive Imaging including our work HDNet, CST, and DAUHST


Mask-guided Spectral-wise Transformer for Efficient Hyperspectral Image Reconstruction (CVPR 2022)


Yuanhao Cai, Jing Lin, Xiaowan Hu, Haoqian Wang, Xin Yuan, Yulun Zhang, Radu Timofte, and Luc Van Gool

The first two authors contributed equally to this work.


News

  • 2022.07.04 : Our paper CST has been accepted by ECCV 2022; code and models are coming soon. 🚀
  • 2022.06.14 : Code and models of MST and MST++ have been released. This repo supports 11 learning-based methods and serves as a toolbox for spectral compressive imaging. The model zoo will be enlarged. 🔥
  • 2022.05.20 : Our work DAUHST is on arXiv. 💫
  • 2022.04.02 : Our follow-up work MST++ won the NTIRE 2022 Spectral Reconstruction Challenge. 🏆
  • 2022.03.09 : Our work CST is on arXiv. 💫
  • 2022.03.02 : Our paper MST has been accepted by CVPR 2022; code and models are coming soon. 🚀
[Reconstruction previews on Scene 2, Scene 3, Scene 4, and Scene 7]

Abstract: Hyperspectral image (HSI) reconstruction aims to recover the 3D spatial-spectral signal from a 2D measurement in the coded aperture snapshot spectral imaging (CASSI) system. The HSI representations are highly similar and correlated across the spectral dimension. Modeling the inter-spectra interactions is beneficial for HSI reconstruction. However, existing CNN-based methods show limitations in capturing spectral-wise similarity and long-range dependencies. Besides, the HSI information is modulated by a coded aperture (physical mask) in CASSI. Nonetheless, current algorithms have not fully explored the guidance effect of the mask for HSI restoration. In this paper, we propose a novel framework, Mask-guided Spectral-wise Transformer (MST), for HSI reconstruction. Specifically, we present a Spectral-wise Multi-head Self-Attention (S-MSA) that treats each spectral feature as a token and calculates self-attention along the spectral dimension. In addition, we customize a Mask-guided Mechanism (MM) that directs S-MSA to pay attention to spatial regions with high-fidelity spectral representations. Extensive experiments show that our MST significantly outperforms state-of-the-art (SOTA) methods on simulation and real HSI datasets while requiring dramatically cheaper computational and memory costs.
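For intuition, the core S-MSA operation can be sketched in PyTorch. The snippet below is a simplified illustration, not the repo's exact module (it omits the Mask-guided Mechanism and other details): each spectral channel is a token whose feature vector is the flattened spatial map, queries and keys are normalized, and a learnable temperature rescales the spectral attention map.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralMSA(nn.Module):
    # Simplified sketch of Spectral-wise Multi-head Self-Attention (S-MSA).
    # Tokens are spectral channels, so the attention map scales with the
    # number of channels C rather than with the number of pixels H*W.
    def __init__(self, dim, heads=4):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))  # learnable temperature
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (B, N, C), with N = H * W
        b, n, c = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # Split heads along the channel dim; each token is a spectral channel
        # whose feature vector is the flattened spatial map of length N.
        q, k, v = (t.reshape(b, n, self.heads, c // self.heads).permute(0, 2, 3, 1)
                   for t in (q, k, v))
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.rescale  # spectral attention map
        out = attn.softmax(dim=-1) @ v                   # (B, heads, C/heads, N)
        out = out.permute(0, 3, 1, 2).reshape(b, n, c)
        return self.proj(out)

# e.g. a 28-band cube of size 256 x 256, flattened to spatial tokens:
# y = SpectralMSA(dim=28, heads=4)(torch.randn(1, 256 * 256, 28))

Because the attention map is computed over channels, its cost grows with C rather than with the image resolution, which is why this design stays cheap on high-resolution HSIs.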


Diagram of Our Method

[Illustration of the MST architecture]

Comparison with State-of-the-art Methods

This repo is a baseline and toolbox containing 11 learning-based algorithms for spectral compressive imaging.

We are going to enlarge our model zoo in the future.

Supported algorithms: the 11 entries in the quantitative comparison table below.

[Comparison figure of the supported algorithms]

Quantitative Comparison on Simulation Dataset

| Method | Params (M) | FLOPS (G) | PSNR | SSIM | Model Zoo | Simulation Result | Real Result |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| λ-Net | 62.64 | 117.98 | 28.53 | 0.841 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | - |
| TSA-Net | 44.25 | 110.06 | 31.46 | 0.894 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | - |
| DGSMP | 3.76 | 646.65 | 32.63 | 0.917 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | - |
| GAP-Net | 4.27 | 78.58 | 33.26 | 0.917 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | - |
| ADMM-Net | 4.27 | 78.58 | 33.58 | 0.918 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | - |
| BIRNAT | 4.40 | 2122.66 | 37.58 | 0.960 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | - |
| HDNet | 2.37 | 154.76 | 34.97 | 0.943 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
| MST-S | 0.93 | 12.96 | 34.26 | 0.935 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
| MST-M | 1.50 | 18.07 | 34.94 | 0.943 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
| MST-L | 2.03 | 28.15 | 35.18 | 0.948 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | Google Drive / Baidu Disk |
| MST++ | 1.33 | 19.42 | 35.99 | 0.951 | Google Drive / Baidu Disk | Google Drive / Baidu Disk | Google Drive / Baidu Disk |

Performance is reported on 10 scenes of the KAIST dataset. FLOPS are measured with a test size of 256 × 256.

Note: the access code for Baidu Disk is mst1
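If you want to reproduce the Params/FLOPS figures yourself, one option is the thop package (an assumption: it is not bundled with this repo, and the input signature of each method may differ). A minimal sketch at the 256 × 256 test size:

import torch
from thop import profile  # pip install thop

# Stand-in model for illustration; substitute the network you want to profile.
model = torch.nn.Conv2d(28, 28, kernel_size=3, padding=1)

x = torch.randn(1, 28, 256, 256)  # 28-band cube at the 256 x 256 test size
# Note: thop counts multiply-accumulate ops; conventions for "FLOPS" vary.
flops, params = profile(model, inputs=(x,))
print(f"FLOPS: {flops / 1e9:.2f} G | Params: {params / 1e6:.2f} M")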

1. Create Environment:

  • Python 3 (Anaconda is recommended)

  • NVIDIA GPU + CUDA

  • Python packages:

pip install -r requirements.txt
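For example, with Anaconda (the environment name mst and the Python version below are illustrative, not requirements of this repo):

conda create -n mst python=3.8 -y
conda activate mst
pip install -r requirements.txt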

2. Prepare Dataset:

Download cave_1024_28 (One Drive), CAVE_512_28 (Baidu Disk, code: ixoe | One Drive), KAIST_CVPR2021 (Baidu Disk, code: 5mmn | One Drive), TSA_simu_data (One Drive), and TSA_real_data (One Drive), then put them into the corresponding folders of datasets/ and organize them in the following structure:

|--MST
    |--real
    	|-- test_code
    	|-- train_code
    |--simulation
    	|-- test_code
    	|-- train_code
    |--visualization
    |--datasets
        |--cave_1024_28
            |--scene1.mat
            |--scene2.mat
            :  
            |--scene205.mat
        |--CAVE_512_28
            |--scene1.mat
            |--scene2.mat
            :  
            |--scene30.mat
        |--KAIST_CVPR2021  
            |--1.mat
            |--2.mat
            : 
            |--30.mat
        |--TSA_simu_data  
            |--mask.mat   
            |--Truth
                |--scene01.mat
                |--scene02.mat
                : 
                |--scene10.mat
        |--TSA_real_data  
            |--mask.mat   
            |--Measurements
                |--scene1.mat
                |--scene2.mat
                : 
                |--scene5.mat

Following TSA-Net and DGSMP, we use the CAVE dataset (cave_1024_28) as the simulation training set. Both the CAVE (CAVE_512_28) and KAIST (KAIST_CVPR2021) datasets are used as the real training set.
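For intuition, the CASSI forward model that produces a simulation measurement from a ground-truth cube can be sketched as follows: every spectral band is modulated by the same physical mask, sheared along one spatial axis by a fixed step per band, and summed into a single 2D measurement. This is a schematic sketch (ignoring noise and normalization), not the repo's exact data pipeline:

import numpy as np

def cassi_forward(hsi, mask, step=2):
    # Schematic CASSI measurement: mask-modulate each band, shift it by
    # `step` pixels per band index, and sum over the spectral dimension.
    # hsi: (H, W, C) ground-truth cube; mask: (H, W) coded aperture.
    h, w, c = hsi.shape
    meas = np.zeros((h, w + step * (c - 1)), dtype=hsi.dtype)
    for i in range(c):
        meas[:, i * step : i * step + w] += hsi[:, :, i] * mask
    return meas

# With 28 bands and a 2-pixel step, a 256 x 256 cube yields a 256 x 310 measurement.
y = cassi_forward(np.random.rand(256, 256, 28), np.random.rand(256, 256) > 0.5)
print(y.shape)  # (256, 310)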

3. Simulation Experiment:

(1) Training:

cd MST/simulation/train_code/

# MST_S
python train.py --template mst_s --outf ./exp/mst_s/ --method mst_s 

# MST_M
python train.py --template mst_m --outf ./exp/mst_m/ --method mst_m  

# MST_L
python train.py --template mst_l --outf ./exp/mst_l/ --method mst_l 

# GAP-Net
python train.py --template gap_net --outf ./exp/gap_net/ --method gap_net 

# ADMM-Net
python train.py --template admm_net --outf ./exp/admm_net/ --method admm_net 

# TSA-Net
python train.py --template tsa_net --outf ./exp/tsa_net/ --method tsa_net 

# HDNet
python train.py --template hdnet --outf ./exp/hdnet/ --method hdnet 

# DGSMP
python train.py --template dgsmp --outf ./exp/dgsmp/ --method dgsmp 

# BIRNAT
python train.py --template birnat --outf ./exp/birnat/ --method birnat 

# MST_Plus_Plus
python train.py --template mst_plus_plus --outf ./exp/mst_plus_plus/ --method mst_plus_plus 

# λ-Net
python train.py --template lambda_net --outf ./exp/lambda_net/ --method lambda_net

The training log, trained models, and reconstructed HSIs will be available in MST/simulation/train_code/exp/ .

(2) Testing:

Download the pretrained model zoo from (Google Drive / Baidu Disk, code: mst1) and place it in MST/simulation/test_code/model_zoo/
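To sanity-check a downloaded checkpoint before testing, a minimal sketch (assuming the .pth file stores a plain state_dict; some checkpoints wrap it, e.g. under a 'state_dict' key) is:

import torch

ckpt = torch.load('./model_zoo/mst/mst_s.pth', map_location='cpu')
state = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
for name in list(state)[:5]:  # print the first few parameter names and shapes
    print(name, tuple(state[name].shape))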

Run the following command to test the model on the simulation dataset.

cd MST/simulation/test_code/

# MST_S
python test.py --template mst_s --outf ./exp/mst_s/ --method mst_s --pretrained_model_path ./model_zoo/mst/mst_s.pth

# MST_M
python test.py --template mst_m --outf ./exp/mst_m/ --method mst_m --pretrained_model_path ./model_zoo/mst/mst_m.pth

# MST_L
python test.py --template mst_l --outf ./exp/mst_l/ --method mst_l --pretrained_model_path ./model_zoo/mst/mst_l.pth

# GAP_Net
python test.py --template gap_net --outf ./exp/gap_net/ --method gap_net --pretrained_model_path ./model_zoo/gap_net/gap_net.pth

# ADMM_Net
python test.py --template admm_net --outf ./exp/admm_net/ --method admm_net --pretrained_model_path ./model_zoo/admm_net/admm_net.pth

# TSA_Net
python test.py --template tsa_net --outf ./exp/tsa_net/ --method tsa_net --pretrained_model_path ./model_zoo/tsa_net/tsa_net.pth

# HDNet
python test.py --template hdnet --outf ./exp/hdnet/ --method hdnet --pretrained_model_path ./model_zoo/hdnet/hdnet.pth

# DGSMP
python test.py --template dgsmp --outf ./exp/dgsmp/ --method dgsmp --pretrained_model_path ./model_zoo/dgsmp/dgsmp.pth

# BIRNAT
python test.py --template birnat --outf ./exp/birnat/ --method birnat --pretrained_model_path ./model_zoo/birnat/birnat.pth

# MST_Plus_Plus
python test.py --template mst_plus_plus --outf ./exp/mst_plus_plus/ --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus/mst_plus_plus.pth

# λ-Net
python test.py --template lambda_net --outf ./exp/lambda_net/ --method lambda_net --pretrained_model_path ./model_zoo/lambda_net/lambda_net.pth
  • The reconstructed HSIs will be output into MST/simulation/test_code/exp/

  • Place the reconstructed results in MST/simulation/test_code/Quality_Metrics/results and run cal_quality_assessment.m in MATLAB to calculate the PSNR and SSIM of the reconstructed HSIs.
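If you prefer Python for the metrics, a rough equivalent using scikit-image is sketched below (an assumption: results may differ slightly from the MATLAB script due to implementation details):

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def hsi_psnr_ssim(truth, recon):
    # Average per-band PSNR/SSIM between two (H, W, C) cubes scaled to [0, 1].
    psnrs = [peak_signal_noise_ratio(truth[..., i], recon[..., i], data_range=1.0)
             for i in range(truth.shape[2])]
    ssims = [structural_similarity(truth[..., i], recon[..., i], data_range=1.0)
             for i in range(truth.shape[2])]
    return float(np.mean(psnrs)), float(np.mean(ssims))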

(3) Visualization:

  • Put the reconstructed HSI in MST/visualization/simulation_results/results and rename it as method.mat, e.g., mst_s.mat.

  • Generate the RGB images of the reconstructed HSIs

cd MST/visualization/
Run show_simulation.m in MATLAB

  • Draw the spectral density lines

cd MST/visualization/
Run show_line.m in MATLAB
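An equivalent sketch in Python (the .mat variable name 'pred', the cube layout, and the pixel location are assumptions; inspect the file with scipy.io.whosmat first):

import scipy.io as sio
import matplotlib.pyplot as plt

data = sio.loadmat('simulation_results/results/mst_s.mat')
cube = data['pred']                         # assumed variable name, (H, W, C) layout
cube = cube if cube.ndim == 3 else cube[0]  # drop a leading scene axis if present

plt.plot(cube[128, 128, :])                 # spectrum of an example pixel
plt.xlabel('Spectral band')
plt.ylabel('Intensity')
plt.show()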

4. Real Experiment:

(1) Training:

cd MST/real/train_code/

# MST_S
python train.py --template mst_s --outf ./exp/mst_s/ --method mst_s 

# MST_M
python train.py --template mst_m --outf ./exp/mst_m/ --method mst_m  

# MST_L
python train.py --template mst_l --outf ./exp/mst_l/ --method mst_l 

# GAP-Net
python train.py --template gap_net --outf ./exp/gap_net/ --method gap_net 

# ADMM-Net
python train.py --template admm_net --outf ./exp/admm_net/ --method admm_net 

# TSA-Net
python train.py --template tsa_net --outf ./exp/tsa_net/ --method tsa_net 

# HDNet
python train.py --template hdnet --outf ./exp/hdnet/ --method hdnet 

# DGSMP
python train.py --template dgsmp --outf ./exp/dgsmp/ --method dgsmp 

# BIRNAT
python train.py --template birnat --outf ./exp/birnat/ --method birnat 

# MST_Plus_Plus
python train.py --template mst_plus_plus --outf ./exp/mst_plus_plus/ --method mst_plus_plus 

# λ-Net
python train.py --template lambda_net --outf ./exp/lambda_net/ --method lambda_net

The training log, trained models, and reconstructed HSIs will be available in MST/real/train_code/exp/ .

(2) Testing:

cd MST/real/test_code/

# MST_S
python test.py --template mst_s --outf ./exp/mst_s/ --method mst_s --pretrained_model_path ./model_zoo/mst/mst_s.pth

# MST_M
python test.py --template mst_m --outf ./exp/mst_m/ --method mst_m --pretrained_model_path ./model_zoo/mst/mst_m.pth

# MST_L
python test.py --template mst_l --outf ./exp/mst_l/ --method mst_l --pretrained_model_path ./model_zoo/mst/mst_l.pth

# GAP_Net
python test.py --template gap_net --outf ./exp/gap_net/ --method gap_net --pretrained_model_path ./model_zoo/gap_net/gap_net.pth

# ADMM_Net
python test.py --template admm_net --outf ./exp/admm_net/ --method admm_net --pretrained_model_path ./model_zoo/admm_net/admm_net.pth

# TSA_Net
python test.py --template tsa_net --outf ./exp/tsa_net/ --method tsa_net --pretrained_model_path ./model_zoo/tsa_net/tsa_net.pth

# HDNet
python test.py --template hdnet --outf ./exp/hdnet/ --method hdnet --pretrained_model_path ./model_zoo/hdnet/hdnet.pth

# DGSMP
python test.py --template dgsmp --outf ./exp/dgsmp/ --method dgsmp --pretrained_model_path ./model_zoo/dgsmp/dgsmp.pth

# BIRNAT
python test.py --template birnat --outf ./exp/birnat/ --method birnat --pretrained_model_path ./model_zoo/birnat/birnat.pth

# MST_Plus_Plus
python test.py --template mst_plus_plus --outf ./exp/mst_plus_plus/ --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus/mst_plus_plus.pth

# λ-Net
python test.py --template lambda_net --outf ./exp/lambda_net/ --method lambda_net --pretrained_model_path ./model_zoo/lambda_net/lambda_net.pth
  • The reconstructed HSIs will be output into MST/real/test_code/exp/

(3) Visualization:

  • Put the reconstructed HSI in MST/visualization/real_results/results and rename it as method.mat, e.g., mst_plus_plus.mat.

  • Generate the RGB images of the reconstructed HSI

cd MST/visualization/
Run show_real.m in MATLAB
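For a quick preview without MATLAB, a pseudo-RGB image can be rendered by mapping three bands to R/G/B (the band indices and .mat variable name below are illustrative assumptions, not the color-accurate conversion the script performs):

import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

data = sio.loadmat('real_results/results/mst_plus_plus.mat')
cube = data['pred']                         # assumed variable name, (H, W, C) layout
cube = cube if cube.ndim == 3 else cube[0]

# Map three of the 28 bands to R/G/B; indices are arbitrary examples.
rgb = np.stack([cube[..., 24], cube[..., 13], cube[..., 3]], axis=-1)
rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)  # normalize for display
plt.imshow(rgb)
plt.axis('off')
plt.show()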

Citation

If this repo helps you, please consider citing our works:

# MST
@inproceedings{mst,
  title={Mask-guided Spectral-wise Transformer for Efficient Hyperspectral Image Reconstruction},
  author={Yuanhao Cai and Jing Lin and Xiaowan Hu and Haoqian Wang and Xin Yuan and Yulun Zhang and Radu Timofte and Luc Van Gool},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}


# MST++
@inproceedings{mst_pp,
  title={MST++: Multi-stage Spectral-wise Transformer for Efficient Spectral Reconstruction},
  author={Yuanhao Cai and Jing Lin and Zudi Lin and Haoqian Wang and Yulun Zhang and Hanspeter Pfister and Radu Timofte and Luc Van Gool},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
  year={2022}
}


# HDNet
@inproceedings{hdnet,
  title={HDNet: High-resolution Dual-domain Learning for Spectral Compressive Imaging},
  author={Xiaowan Hu and Yuanhao Cai and Jing Lin and Haoqian Wang and Xin Yuan and Yulun Zhang and Radu Timofte and Luc Van Gool},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}


# CST
@inproceedings{cst,
  title={Coarse-to-Fine Sparse Transformer for Hyperspectral Image Reconstruction},
  author={Yuanhao Cai and Jing Lin and Xiaowan Hu and Haoqian Wang and Xin Yuan and Yulun Zhang and Radu Timofte and Luc Van Gool},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022}
}