This is the PyTorch implementation of our paper "RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model"

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model




Introduction

This repository is the code implementation of the paper RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model, which is based on the MMDetection project.

The current branch has been tested under PyTorch 2.x and CUDA 12.1, supports Python 3.7+, and is compatible with most CUDA versions.

If you find this project helpful, please give us a star ⭐️. Your support is our greatest motivation.

Main Features
  • API and usage highly consistent with MMDetection
  • Open-source implementations of SAM-seg, SAM-det, RSPrompter, and other models from the paper
  • Tested with AMP, DeepSpeed, and other training methods
  • Supports training and testing on multiple datasets

Update Log

🌟 2023.06.29 Released the RSPrompter project, which implements the SAM-seg, SAM-det, RSPrompter and other models in the paper based on Lightning and MMDetection.

🌟 2023.11.25 Updated the RSPrompter code to be fully consistent with the API and usage of MMDetection.

🌟 2023.11.26 Added the LoRA efficient fine-tuning method, and made the input image size variable, reducing the memory usage of the model.

🌟 2023.11.26 Provided a reference for the memory usage of each model, see Common Problems for details.

🌟 2023.11.30 Updated the paper content; see arXiv for details.

TODO

  • API and usage consistent with MMDetection
  • Reduce the memory usage of the model while maintaining performance, by reducing the input image size and applying large-model fine-tuning techniques
  • Dynamically variable image size input
  • Efficient fine-tuning method in the model
  • Add SAM-cls model

Installation

Dependencies

  • Linux or Windows
  • Python 3.7+, recommended 3.10
  • PyTorch 2.0 or higher, recommended 2.1
  • CUDA 11.7 or higher, recommended 12.1
  • MMCV 2.0 or higher, recommended 2.1

Environment Installation

We recommend using Miniconda for installation. The following commands create a virtual environment named rsprompter and install PyTorch and MMCV.

Note: If you have experience with PyTorch and have already installed it, you can skip to the next section. Otherwise, you can follow these steps to prepare.

Step 0: Install Miniconda.

Step 1: Create a virtual environment named rsprompter and activate it.

conda create -n rsprompter python=3.10 -y
conda activate rsprompter

Step 2: Install PyTorch 2.1.x.

Linux/Windows:

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121

Or

conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia

Step 3: Install MMCV 2.1.x.

pip install -U openmim
mim install mmcv==2.1.0

Step 4: Install other dependencies.

pip install -U transformers==4.38.1 wandb==0.16.3 einops pycocotools shapely scipy terminaltables importlib peft==0.8.2 mat4py==0.6.0 mpi4py

Step 5: [Optional] Install DeepSpeed.

If you want to use DeepSpeed to train the model, you need to install it. For installation details, refer to the official DeepSpeed documentation.

pip install deepspeed==0.13.4

Note: DeepSpeed support on Windows is still incomplete; we recommend using DeepSpeed on Linux.
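After finishing the steps above, a small script can confirm that the installed packages meet the minimum versions from the Dependencies list. This is a sketch written for this README, not a tool shipped with the repository: the names MINIMUM_VERSIONS, parse_version, and check_requirements are hypothetical. It only reads package metadata through importlib.metadata (available from Python 3.8, so it fits the recommended Python 3.10 environment) and does not check CUDA; once PyTorch is installed, torch.cuda.is_available() is the usual way to do that.

```python
from importlib import metadata

# Hypothetical helper; minimum versions copied from the Dependencies list
MINIMUM_VERSIONS = {"torch": (2, 0), "mmcv": (2, 0)}


def parse_version(text):
    """Turn a version string such as '2.1.2+cu121' into a tuple of ints, e.g. (2, 1, 2)."""
    public = text.split("+", 1)[0]  # drop local build tags such as '+cu121'
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:  # keep only the leading digits, e.g. '0rc1' -> '0'
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_requirements(minimums, get_version=installed_version):
    """Return a list of problems; an empty list means every package meets its minimum."""
    problems = []
    for package, minimum in minimums.items():
        found = get_version(package)
        if found is None:
            problems.append(f"{package} is not installed")
        elif parse_version(found) < minimum:
            problems.append(f"{package} {found} is older than {'.'.join(map(str, minimum))}")
    return problems


if __name__ == "__main__":
    issues = check_requirements(MINIMUM_VERSIONS)
    print("\n".join(issues) if issues else "All checked packages meet the minimum versions.")
```

Passing get_version as a parameter keeps the check easy to test without touching the real environment.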

Install RSPrompter

Download or clone the RSPrompter repository.

git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter

Dataset Preparation

Basic Instance Segmentation Dataset

We provide the instance segmentation dataset preparation method used in the paper.

WHU Building Dataset

  • Image download address: WHU Building Dataset

  • Semantic label to instance label: We provide a conversion script that converts the semantic labels of the WHU Building Dataset into instance labels.
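The repository's conversion script is not reproduced here. As a rough illustration of the idea, assuming the semantic label is a binary building mask, each 4-connected foreground region can be treated as one building instance. The function name semantic_to_instances and its output format (a COCO-style bbox [x, y, width, height] plus pixel area) are hypothetical choices for this sketch; touching buildings would be merged into one instance by this approach, and for real images a library routine such as scipy.ndimage.label would be much faster than pure Python.

```python
from collections import deque


def semantic_to_instances(mask):
    """Split a binary semantic mask (list of rows of 0/1) into 4-connected instances.

    Returns one dict per instance with a COCO-style bbox [x, y, width, height],
    the pixel area, and the (row, col) pixels belonging to the instance.
    """
    height = len(mask)
    width = len(mask[0]) if height else 0
    seen = [[False] * width for _ in range(height)]
    instances = []
    for r in range(height):
        for c in range(width):
            if mask[r][c] == 0 or seen[r][c]:
                continue
            # Breadth-first flood fill starting from an unvisited foreground pixel
            queue = deque([(r, c)])
            seen[r][c] = True
            pixels = []
            while queue:
                y, x = queue.popleft()
                pixels.append((y, x))
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < height and 0 <= nx < width and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            ys = [p[0] for p in pixels]
            xs = [p[1] for p in pixels]
            x0, y0 = min(xs), min(ys)
            instances.append({
                "bbox": [x0, y0, max(xs) - x0 + 1, max(ys) - y0 + 1],
                "area": len(pixels),
                "pixels": pixels,
            })
    return instances
```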

NWPU VHR-10 Dataset

SSDD Dataset

Note: In the data folder of this project, we provide the instance labels of the above datasets, which you can use directly.

Organization Method

You can also download the data from other sources, but the dataset must be organized in the following format:

${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/NWPU
├── annotations
│   ├── train.json
│   ├── val.json
│   └── test.json
└── images
    ├── train
    ├── val
    └── test

Note: The data folder in this project contains examples of this organization for the datasets above.
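Before training, it can help to check that a dataset follows the layout above. The sketch below is a hypothetical helper, not part of the repository; it assumes the annotation files are COCO-style JSON with images, annotations, and categories lists (which the pycocotools dependency suggests, but please confirm against the examples in the data folder).

```python
import json
from pathlib import Path

SPLITS = ("train", "val", "test")


def missing_paths(dataset_root):
    """List the expected annotation files and image folders that do not exist."""
    root = Path(dataset_root)
    expected = [root / "annotations" / f"{split}.json" for split in SPLITS]
    expected += [root / "images" / split for split in SPLITS]
    return [p.relative_to(root).as_posix() for p in expected if not p.exists()]


def summarize_annotations(json_path):
    """Count images, annotations, and categories in a COCO-style annotation file."""
    with open(json_path, encoding="utf-8") as f:
        coco = json.load(f)
    return {key: len(coco.get(key, [])) for key in ("images", "annotations", "categories")}
```

For example, missing_paths("/home/username/data/NWPU") returns an empty list when the layout is complete; otherwise it names each missing file or folder relative to the dataset root.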

Other Datasets

If you want to use other datasets, you can refer to MMDetection documentation to prepare the datasets.

Model Training

SAM-based Model

Config File and Main Parameter Parsing

We provide the configuration files of the SAM-based models used in the paper in the configs/rsprompter folder. The Config files follow exactly the same API and usage as MMDetection. Below we describe some of the main parameters; for the meaning of other parameters, refer to the MMDetection documentation.

Parameter Descriptions:

  • work_dir: The output path for model training, which generally does not need to be modified.
  • default_hooks-CheckpointHook: The checkpoint saving configuration during training, which generally does not need to be modified.
  • default_hooks-visualization: The visualization configuration; comment it out during training and uncomment it during testing.
  • vis_backends-WandbVisBackend: The configuration of the web-based visualization backend. After uncommenting it, you need to register an account on the wandb website; you can then view the visualization results during training in a web browser.
  • num_classes: The number of categories in the dataset, which must be set according to your dataset.
  • prompt_shape: The shape of the prompt; the first value is $N_p$ and the second is $K_p$. It generally does not need to be modified.
  • hf_sam_pretrain_name: The name of the SAM model on the Hugging Face Hub; change it to your own local path. You can use the download script to download the model.
  • hf_sam_pretrain_ckpt_path: The checkpoint path of the SAM model from the Hugging Face Hub; change it to your own local path. You can use the download script to download it.
  • model-decoder_freeze: Whether to freeze the parameters of the SAM decoder, which generally does not need to be modified.
  • model-neck-feature_aggregator-hidden_channels: The number of hidden channels in the feature aggregator, which generally does not need to be modified.
  • model-neck-feature_aggregator-select_layers: The backbone layers selected by the feature aggregator, which must be set according to the SAM backbone type you choose.
  • model-mask_head-with_sincos: Whether to apply the sine-cosine transformation when predicting prompts, which generally does not need to be modified.
  • dataset_type: The dataset type, which must be set according to your dataset.
  • code_root: The code root directory; set it to the absolute path of this project's root directory.
  • data_root: The dataset root directory; set it to the absolute path of the dataset root directory.
  • batch_size_per_gpu: The batch size per GPU, which should be set according to the available GPU memory.
  • resume: Whether to resume training, which generally does not need to be modified.
  • load_from: The path of a pretrained checkpoint to load, which generally does not need to be modified.
  • max_epochs: The maximum number of training epochs, which generally does not need to be modified.
  • runner_type: The runner type, which must be consistent with the types of optim_wrapper and strategy; it generally does not need to be modified.
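To make the parameters that usually need editing more concrete, here is a sketch of how they might look in a Config file, assuming (as the list above suggests) that num_classes, code_root, data_root, and batch_size_per_gpu are plain top-level variables. The paths and values are hypothetical examples for NWPU VHR-10 stored in the layout from Dataset Preparation; the actual files in configs/rsprompter contain many more settings and may organize them differently.

```python
# Hypothetical example values; adapt them to your own dataset, paths, and GPU
num_classes = 10  # NWPU VHR-10 has 10 object categories
code_root = '/home/username/RSPrompter'  # absolute path of this project's root directory
data_root = '/home/username/data/NWPU'  # absolute path of ${DATASET_ROOT}
batch_size_per_gpu = 2  # lower this if training runs out of GPU memory
```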

Single Card Training

python tools/train.py configs/rsprompter/xxx.py  # xxx.py is the configuration file you want to use

Multi-card Training

sh ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM}  # xxx.py is the configuration file you want to use, GPU_NUM is the number of GPUs used

Other Instance Segmentation Models

If you want to use other instance segmentation models, you can refer to MMDetection to train them, or place their Config files in the configs folder of this project and train them with the commands above.

Model Testing

Single Card Testing:

python tools/test.py configs/rsprompter/xxx.py ${CHECKPOINT_FILE}  # xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use

Multi-card Testing:

sh ./tools/dist_test.sh configs/rsprompter/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM}  # xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, GPU_NUM is the number of GPUs used

Note: If you need to get the visualization results, you can uncomment default_hooks-visualization in the Config file.

Image Prediction

Single Image Prediction:

python demo/image_demo.py ${IMAGE_FILE}  configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR}  # IMAGE_FILE is the image file you want to predict, xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result

Multi-image Prediction:

python demo/image_demo.py ${IMAGE_DIR}  configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR}  # IMAGE_DIR is the image folder you want to predict, xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result

Common Problems

We have listed some common problems and their solutions here. If you find that a problem is missing, feel free to open a PR to add it to this list. If you cannot find help here, please open an issue and fill in all the required information in the template, which will help us locate the problem faster.

1. Do I need to install MMDetection?

We recommend that you do not install MMDetection, because we have modified part of the MMDetection code and installing the official package may cause errors. If you encounter an error saying that a module has not been registered, please check:

  • Whether MMDetection is installed; if so, uninstall it
  • Whether @MODELS.register_module() is added above the class definition; if not, add it
  • Whether from .xxx import xxx is added in __init__.py; if not, add it
  • Whether custom_imports = dict(imports=['mmdet.rsprompter'], allow_failed_imports=False) is added in the Config file; if not, add it
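The checklist above follows from how registry-based frameworks resolve the type strings in a Config file. The toy registry below is a simplified illustration written for this README, not MMEngine's actual Registry implementation, and ToyPrompter is a hypothetical class. It shows why a class must be both decorated and imported before a config can refer to it by name.

```python
class Registry:
    """Toy registry mapping class names to classes so configs can refer to them by string."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            self._modules[cls.__name__] = cls  # registration happens at class-definition time
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)  # copy so the caller's config dict is not modified
        type_name = cfg.pop("type")
        if type_name not in self._modules:
            raise KeyError(f"{type_name} is not registered in the {self.name} registry")
        return self._modules[type_name](**cfg)


MODELS = Registry("model")


@MODELS.register_module()  # only runs when the module defining this class is imported
class ToyPrompter:
    def __init__(self, hidden_channels=32):
        self.hidden_channels = hidden_channels


model = MODELS.build(dict(type="ToyPrompter", hidden_channels=64))
```

Because the decorator only runs when the defining module is imported, a class that is never imported (no entry in __init__.py and no custom_imports) stays unregistered, and building it from a config fails with a "not registered" style error.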

2. How to evaluate the model after training with DeepSpeed?

We recommend that you use DeepSpeed to train the model, because it can greatly improve training speed. However, the DeepSpeed training pipeline differs from that of MMDetection, so after training with DeepSpeed you need to convert the checkpoint and then evaluate it with MMDetection. Specifically, you need to:

  • Convert the DeepSpeed checkpoint into a regular checkpoint that MMDetection can load: enter the folder where the checkpoint is stored and run
python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR  # $SAVE_CHECKPOINT_NAME is the name of the converted checkpoint, $CHECKPOINT_DIR is the name of the checkpoint saved by DeepSpeed
  • Change runner_type in the Config file to Runner.
  • Evaluate the converted checkpoint with the MMDetection testing procedure (see Model Testing) to obtain the evaluation results.

3. About resource consumption

Here we list the resource consumption of using different models for your reference.

| Model Name | Backbone | Image Size | GPU | Batch Size | Acceleration Strategy | Memory Usage per GPU |
| --- | --- | --- | --- | --- | --- | --- |
| SAM-seg (Mask R-CNN) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 19.4 GB |
| SAM-seg (Mask2Former) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 21.5 GB |
| SAM-det | ResNet50 | 1024x1024 | 1x RTX 4090 24G | 8 | FP32 | 16.6 GB |
| RSPrompter-anchor | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 2 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 1 | AMP FP16 | OOM |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 8x NVIDIA A100 40G | 1 | ZeRO-2 | 39.6 GB |
| RSPrompter-anchor | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 4 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 2 | ZeRO-2 | 21.1 GB |

Note: Low-resolution input images can effectively reduce the memory usage of the model, but their effect on performance has not been verified. For details, please refer to the corresponding Config files.
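A back-of-the-envelope calculation helps explain why lowering the resolution saves memory. Assuming the image encoder splits the input into 16x16 patches, as the ViT-B/16 backbone name indicates, halving each side of the image cuts the number of tokens by 4x, and any full (global) attention map over those tokens by 16x. This is an illustration only: SAM's image encoder uses windowed attention in most blocks, and other parts of the model also consume memory, so the real savings differ from these ratios.

```python
def num_patch_tokens(height, width, patch_size=16):
    """Number of patch tokens a ViT with the given patch size produces for an image."""
    return (height // patch_size) * (width // patch_size)


tokens_1024 = num_patch_tokens(1024, 1024)  # 64 * 64 = 4096 tokens
tokens_512 = num_patch_tokens(512, 512)  # 32 * 32 = 1024 tokens

token_ratio = tokens_1024 / tokens_512  # 4.0: per-token activations shrink about 4x
attention_ratio = token_ratio ** 2  # 16.0: a global attention map shrinks about 16x
print(tokens_1024, tokens_512, token_ratio, attention_ratio)  # 4096 1024 4.0 16.0
```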

4. Solution to dist_train.sh: Bad substitution

If you encounter the error Bad substitution when running dist_train.sh, run the script with bash instead of sh, i.e., bash ./tools/dist_train.sh.

5. Unable to access or download the model from Hugging Face

If you are unable to access or download the model from Hugging Face, please use the download script to download it, or refer to the official solution.

6. The segmentation loss is always 0 or becomes NaN (Not a Number)

This is caused by unstable training due to a small batch size. There are several solutions below; you can choose any one of them:

  1. Increase the batch size to 2 or 4 (this may exceed the available GPU memory);

  2. Use gradient accumulation (modify optim_wrapper in the Config file):

optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',  # Change to 'bfloat16' for more stability
    optimizer=dict(
        type='AdamW',
        lr=base_lr,  # base_lr is defined elsewhere in the Config file
        weight_decay=0.05),
    accumulative_counts=4  # Add this key; set it to 4 or any other integer greater than 1
)
  3. Disable the sine-cosine transformation in the prompter during decoding (set with_sincos=False in the Config file);

  4. Use a PEFT configuration with a 512x512 input image size, and increase the batch size.
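The idea behind solution 2 can be checked with a tiny example. When the loss is averaged over samples and each micro-batch loss is divided by the number of accumulation steps, summing the micro-batch gradients gives the same gradient as one large batch. The sketch below uses a hypothetical one-parameter linear model with made-up data; it illustrates the general technique, not MMEngine's internal code. Layers that depend on batch statistics, such as BatchNorm, still only see the small micro-batch, so accumulation is not fully equivalent to a larger batch for every model.

```python
def grad_mse(w, batch):
    """Gradient with respect to w of the mean squared error of y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


# Hypothetical toy data: 8 samples of (x, y)
data = [(1.0, 2.0), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8),
        (5.0, 10.1), (6.0, 11.9), (7.0, 14.2), (8.0, 15.8)]
w = 1.5

full_batch_grad = grad_mse(w, data)  # one step with batch size 8

accumulative_counts = 4
micro_batch_size = len(data) // accumulative_counts  # batch size 2 per step
accumulated = 0.0
for i in range(accumulative_counts):
    micro_batch = data[i * micro_batch_size:(i + 1) * micro_batch_size]
    # Each micro-batch gradient is scaled by 1 / accumulative_counts before summing
    accumulated += grad_mse(w, micro_batch) / accumulative_counts

print(full_batch_grad, accumulated)  # equal up to floating-point rounding
```

In distributed training, the effective batch size is roughly batch_size_per_gpu x number of GPUs x accumulative_counts, at the cost of more forward/backward passes per optimizer step.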

Acknowledgement

This project is developed based on the MMDetection project. Thanks to the developers of the MMDetection project.

Citation

If you use the code or performance benchmarks of this project in your research, please cite RSPrompter using the BibTeX entry below.

@article{chen2024rsprompter,
  title={RSPrompter: Learning to prompt for remote sensing instance segmentation based on visual foundation model},
  author={Chen, Keyan and Liu, Chenyang and Chen, Hao and Zhang, Haotian and Li, Wenyuan and Zou, Zhengxia and Shi, Zhenwei},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  year={2024},
  publisher={IEEE}
}

License

This project is licensed under the Apache 2.0 license.

Contact

If you have any other questions ❓, feel free to contact us 👬