Introduction
This repository contains the code for the paper RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model. It is built on the MMDetection project.
The current branch has been tested with PyTorch 2.x and CUDA 12.1, supports Python 3.7+, and is compatible with most CUDA versions.
If you find this project helpful, please give us a star ⭐️. Your support is our greatest motivation.
Main Features
- API and usage highly consistent with MMDetection
- Open-source implementations of SAM-seg, SAM-det, RSPrompter, and other models from the paper
- Tested with AMP, DeepSpeed, and other training methods
- Supports training and testing on multiple datasets
Update Log
🌟 2023.06.29 Released the RSPrompter project, which implements SAM-seg, SAM-det, RSPrompter, and other models from the paper based on Lightning and MMDetection.
🌟 2023.11.25 Updated the RSPrompter code so that its API and usage are fully consistent with MMDetection.
🌟 2023.11.26 Added LoRA efficient fine-tuning and variable input image sizes, reducing the memory usage of the model.
🌟 2023.11.26 Added reference memory usage for each model; see Common Problems for details.
🌟 2023.11.30 Updated the paper; see arXiv for details.
TODO
- API and usage consistent with MMDetection
- Reduce the memory usage of the model while maintaining performance, by using smaller input images combined with efficient fine-tuning of the large model
- Dynamically variable input image size
- Efficient fine-tuning methods for the model
- Add the SAM-cls model
Table of Contents
- Introduction
- Update Log
- TODO
- Table of Contents
- Installation
- Dataset Preparation
- Model Training
- Model Testing
- Image Prediction
- Common Problems
- Acknowledgement
- Citation
- License
- Contact
Installation
Dependencies
- Linux or Windows
- Python 3.7+, recommended 3.10
- PyTorch 2.0 or higher, recommended 2.1
- CUDA 11.7 or higher, recommended 12.1
- MMCV 2.0 or higher, recommended 2.1
Environment Installation
We recommend using Miniconda for installation. The following commands create a virtual environment named rsprompter and install PyTorch and MMCV.
Note: If you are experienced with PyTorch and have already installed it, you can skip to the next section. Otherwise, follow the steps below.
Step 0: Install Miniconda.
Step 1: Create a virtual environment named rsprompter and activate it.
conda create -n rsprompter python=3.10 -y
conda activate rsprompter
Step 2: Install PyTorch 2.1.x.
Linux/Windows:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
Or
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
Step 3: Install MMCV 2.1.x.
pip install -U openmim
mim install mmcv==2.1.0
Step 4: Install other dependencies.
pip install -U transformers==4.38.1 wandb==0.16.3 einops pycocotools shapely scipy terminaltables importlib peft==0.8.2 mat4py==0.6.0 mpi4py
Step 5: [Optional] Install DeepSpeed.
To train the model with DeepSpeed, you need to install it. For installation details, refer to the official DeepSpeed documentation.
pip install deepspeed==0.13.4
Note: DeepSpeed support on Windows is still incomplete; we recommend using DeepSpeed on Linux.
Install RSPrompter
Download or clone the RSPrompter repository.
git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter
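After these steps, you may want to confirm which package versions actually ended up in the rsprompter environment. The script below is a hypothetical helper, not part of the repository. It only reads installed package metadata through the standard-library importlib.metadata module (Python 3.8+), and the expected versions are copied from the installation commands above. DeepSpeed is optional, so a missing deepspeed entry is expected if you skipped Step 5.

```python
"""Hypothetical helper (not part of RSPrompter): compare installed package
versions with the versions pinned in the installation steps above."""
from importlib.metadata import PackageNotFoundError, version

# Versions taken from Steps 2-5; deepspeed is optional.
PINNED = {
    "torch": "2.1.2",
    "torchvision": "0.16.2",
    "torchaudio": "2.1.2",
    "mmcv": "2.1.0",
    "transformers": "4.38.1",
    "wandb": "0.16.3",
    "peft": "0.8.2",
    "mat4py": "0.6.0",
    "deepspeed": "0.13.4",
}


def check_versions(pinned=PINNED):
    """Return {package: (expected, installed_or_None, matches)}."""
    report = {}
    for name, expected in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Wheels from the PyTorch index carry a local suffix such as
        # "2.1.2+cu121", so only the part before "+" is compared.
        matches = installed is not None and installed.split("+")[0] == expected
        report[name] = (expected, installed, matches)
    return report


if __name__ == "__main__":
    for name, (expected, installed, ok) in check_versions().items():
        status = "OK" if ok else "CHECK"
        print(f"{status:5} {name:12} expected {expected:8} installed {installed}")
```

A CHECK line does not necessarily mean the environment is broken; it only flags a difference from the recommended versions.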
Dataset Preparation
Below we describe how to prepare the instance segmentation datasets used in the paper.
WHU Building Dataset
- Image download address: WHU Building Dataset.
- Semantic label to instance label: We provide a conversion script that converts the semantic labels of the WHU Building dataset into instance labels.
NWPU VHR-10 Dataset
- Image download address: NWPU VHR-10 Dataset.
- Instance label download address: NWPU VHR-10 Instance Label.
SSDD Dataset
- Image download address: SSDD Dataset.
- Instance label download address: SSDD Instance Label.
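The repository's WHU conversion script is not reproduced here. As a hedged illustration of the general idea only, one common way to turn a binary building mask into instance labels is connected-component labeling, where each 4-connected group of building pixels becomes one instance. The function name and the plain-list mask format below are assumptions made for this sketch; the actual script may use a different connectivity, library, or post-processing.

```python
from collections import deque


def semantic_to_instances(mask):
    """Label 4-connected foreground regions of a binary mask (sketch).

    mask: list of rows containing 0 (background) or 1 (building).
    Returns (labels, count): labels has the same shape as mask, 0 marks
    background and 1..count mark individual instances.
    """
    height, width = len(mask), len(mask[0]) if mask else 0
    labels = [[0] * width for _ in range(height)]
    count = 0
    for y in range(height):
        for x in range(width):
            if mask[y][x] and not labels[y][x]:
                count += 1  # an unvisited building pixel starts a new instance
                labels[y][x] = count
                queue = deque([(y, x)])
                while queue:  # breadth-first flood fill over 4 neighbours
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if (0 <= ny < height and 0 <= nx < width
                                and mask[ny][nx] and not labels[ny][nx]):
                            labels[ny][nx] = count
                            queue.append((ny, nx))
    return labels, count


mask = [
    [1, 1, 0, 0],
    [1, 0, 0, 1],
    [0, 0, 1, 1],
]
labels, count = semantic_to_instances(mask)
print(count)   # 2
print(labels)  # [[1, 1, 0, 0], [1, 0, 0, 2], [0, 0, 2, 2]]
```

With NumPy arrays, scipy.ndimage.label (scipy is among the dependencies installed in Step 4) performs the same kind of labeling; its default structuring element is also 4-connected in 2D.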
Note: The data folder of this project provides the instance labels of the above datasets, which you can use directly.
You can also download the data from other sources, but the dataset must be organized in the following format:
${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/NWPU
├── annotations
│   ├── train.json
│   ├── val.json
│   └── test.json
└── images
    ├── train
    ├── val
    └── test
Note: The data folder in the project root contains examples of the dataset organization shown above.
To use other datasets, refer to the MMDetection documentation for dataset preparation.
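Before training, it can save time to check that a dataset root matches the layout described in this section. The function below is a hypothetical helper (not part of the repository); it assumes the three split names from the tree and checks only that each annotation file is JSON with the three top-level keys of a COCO instance-segmentation file. It does not validate the annotation contents themselves.

```python
import json
from pathlib import Path

SPLITS = ("train", "val", "test")


def check_dataset_layout(dataset_root):
    """Return a list of problems found under a ${DATASET_ROOT} directory.

    Expects annotations/<split>.json (COCO-style JSON) and images/<split>/
    for each split; an empty list means the layout looks as expected.
    """
    root = Path(dataset_root)
    problems = []
    for split in SPLITS:
        ann_file = root / "annotations" / f"{split}.json"
        img_dir = root / "images" / split
        if not img_dir.is_dir():
            problems.append(f"missing image folder: {img_dir}")
        if not ann_file.is_file():
            problems.append(f"missing annotation file: {ann_file}")
            continue
        with ann_file.open(encoding="utf-8") as f:
            data = json.load(f)
        # A COCO instance-segmentation file has these three top-level keys.
        for key in ("images", "annotations", "categories"):
            if key not in data:
                problems.append(f"{ann_file}: missing top-level key '{key}'")
    return problems


# Example (placeholder path from the tree above):
# print(check_dataset_layout("/home/username/data/NWPU") or "layout looks OK")
```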
Model Training
We provide the configuration files of the SAM-based models used in the paper in the configs/rsprompter folder. The Config files follow the same API and usage as MMDetection. Below we explain some of the main parameters; for the meaning of other parameters, refer to the MMDetection documentation.
Main parameters:
- work_dir: The output path of model training; generally does not need to be modified.
- default_hooks-CheckpointHook: Checkpoint saving configuration during training; generally does not need to be modified.
- default_hooks-visualization: Visualization configuration; comment it out during training and uncomment it during testing.
- vis_backends-WandbVisBackend: Configuration of the web-based visualization backend. After uncommenting it, register an account on the wandb website; you can then view the training visualizations in a web browser.
- num_classes: The number of categories in the dataset; must be set according to your dataset.
- prompt_shape: The shape of the prompts. The first value is $N_p$ and the second is $K_p$; generally does not need to be modified.
- hf_sam_pretrain_name: The name of the SAM model on Hugging Face. Change it to your own local path; you can download the model with the download script.
- hf_sam_pretrain_ckpt_path: The checkpoint path of the SAM model from Hugging Face. Change it to your own local path; you can download the checkpoint with the download script.
- model-decoder_freeze: Whether to freeze the parameters of the SAM decoder; generally does not need to be modified.
- model-neck-feature_aggregator-hidden_channels: The number of hidden channels of the feature aggregator; generally does not need to be modified.
- model-neck-feature_aggregator-select_layers: The backbone layers selected by the feature aggregator; must be set according to the selected SAM backbone type.
- model-mask_head-with_sincos: Whether to apply the sine-cosine transformation when predicting prompts; generally does not need to be modified.
- dataset_type: The dataset type; must be set according to your dataset.
- code_root: The code root directory; set it to the absolute path of this project's root directory.
- data_root: The dataset root directory; set it to the absolute path of your dataset root directory.
- batch_size_per_gpu: The batch size per GPU; adjust it according to the available GPU memory.
- resume: Whether to resume training; generally does not need to be modified.
- load_from: The path of a pretrained checkpoint to load; generally does not need to be modified.
- max_epochs: The maximum number of training epochs; generally does not need to be modified.
- runner_type: The runner type, which must be consistent with the types of optim_wrapper and strategy; generally does not need to be modified.
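Since num_classes must match your dataset, one way to double-check it is to count the categories in a COCO-style annotation file from the dataset layout described earlier. The helper below is hypothetical (not part of the repository) and assumes the standard COCO "categories" list with "id" and "name" fields; sorting by id is only for a stable, readable listing and is not claimed to match MMDetection's internal class order.

```python
import json


def count_categories(ann_file):
    """Return (num_classes, class_names) from a COCO-style annotation file."""
    with open(ann_file, encoding="utf-8") as f:
        categories = json.load(f)["categories"]
    # Sort by category id for a deterministic listing.
    names = tuple(c["name"] for c in sorted(categories, key=lambda c: c["id"]))
    return len(names), names


# Example (placeholder path):
# num_classes, names = count_categories("/home/username/data/NWPU/annotations/train.json")
# print(num_classes, names)  # compare with num_classes in the Config file
```

For example, NWPU VHR-10 has 10 object categories, so a correct annotation file for it should report 10.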
Single-GPU training:
python tools/train.py configs/rsprompter/xxx.py # xxx.py is the configuration file you want to use
Multi-GPU training:
sh ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM} # xxx.py is the configuration file you want to use, GPU_NUM is the number of GPUs used
To use other instance segmentation models, you can train them with MMDetection, or put their Config files in the configs folder of this project and train them with the commands above.
Model Testing
Single-GPU testing:
python tools/test.py configs/rsprompter/xxx.py ${CHECKPOINT_FILE} # xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use
Multi-GPU testing:
sh ./tools/dist_test.sh configs/rsprompter/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM} # xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, GPU_NUM is the number of GPUs used
Note: To get visualization results, uncomment default_hooks-visualization in the Config file.
Image Prediction
Single image:
python demo/image_demo.py ${IMAGE_FILE} configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR} # IMAGE_FILE is the image file you want to predict, xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result
Image folder:
python demo/image_demo.py ${IMAGE_DIR} configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR} # IMAGE_DIR is the image folder you want to predict, xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result
Common Problems
We list some common problems and their solutions here. If you find that a problem is missing, feel free to open a PR to extend this list. If you cannot find help here, please open an issue and fill in all the required information in the template, which helps us locate the problem faster.
1. Do I need to install MMDetection?
We recommend that you do not install MMDetection: this project includes a modified copy of the MMDetection code, and a separately installed MMDetection may cause errors. If you encounter an error saying that a module is not registered, check the following:
- Whether MMDetection is installed; if so, uninstall it.
- Whether @MODELS.register_module() is added above the class definition; if not, add it.
- Whether from .xxx import xxx is added in __init__.py; if not, add it.
- Whether custom_imports = dict(imports=['mmdet.rsprompter'], allow_failed_imports=False) is added in the Config file; if not, add it.
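To see why these checks matter, here is a deliberately simplified registry written from scratch. It is a hypothetical toy, not MMEngine's Registry, but it follows the same basic idea: a class only becomes buildable by name once the module that defines it has been imported, because the decorator runs when the class definition executes. That is what the __init__.py import and custom_imports ensure; without them, building the model from a config dict fails with a "not registered" style error. ToyPrompter and NeverImportedModel are illustrative names.

```python
class Registry:
    """Toy name -> class registry (a simplified stand-in, not MMEngine's)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Returns a decorator so it can be used as @MODELS.register_module().
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)  # copy so the caller's config dict is not modified
        type_name = cfg.pop("type")
        if type_name not in self._modules:
            raise KeyError(f"{type_name} is not in the {self.name} registry")
        return self._modules[type_name](**cfg)


MODELS = Registry("model")


# Registration happens as a side effect when this class definition executes,
# i.e. when the module containing it is imported.
@MODELS.register_module()
class ToyPrompter:
    def __init__(self, num_classes):
        self.num_classes = num_classes


model = MODELS.build(dict(type="ToyPrompter", num_classes=10))
print(model.num_classes)  # 10

try:
    MODELS.build(dict(type="NeverImportedModel"))
except KeyError as err:
    print(err)  # the class was never registered, so building it fails
```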
2. How do I evaluate a model trained with DeepSpeed?
We recommend training with DeepSpeed because it can greatly speed up training. However, DeepSpeed saves checkpoints in its own format, which differs from MMDetection's, so a model trained with DeepSpeed must be converted before it can be evaluated with MMDetection's tools. Specifically:
- Convert the DeepSpeed checkpoint into an MMDetection checkpoint: enter the folder where the model is stored and run
python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR # $SAVE_CHECKPOINT_NAME is the name of the converted model, $CHECKPOINT_DIR is the name of the checkpoint folder saved by DeepSpeed
- Change runner_type in the Config file to Runner.
- Evaluate the converted model with MMDetection's testing procedure (see Model Testing) to obtain the evaluation results.
3. Resource consumption
The table below lists the resource consumption of different models for your reference.
| Model Name | Backbone | Image Size | GPU | Batch Size | Acceleration Strategy | Memory Usage per GPU |
|---|---|---|---|---|---|---|
| SAM-seg (Mask R-CNN) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 19.4 GB |
| SAM-seg (Mask2Former) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 21.5 GB |
| SAM-det | ResNet50 | 1024x1024 | 1x RTX 4090 24G | 8 | FP32 | 16.6 GB |
| RSPrompter-anchor | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 2 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 1 | AMP FP16 | OOM |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 8x NVIDIA A100 40G | 1 | ZeRO-2 | 39.6 GB |
| RSPrompter-anchor | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 4 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 2 | ZeRO-2 | 21.1 GB |
Note: Low-resolution input images can effectively reduce the memory usage of the model, but the resulting accuracy has not been verified. For details, refer to the corresponding Config files.
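4. Fixing Bad substitution when running dist_train.sh
If you encounter the error Bad substitution when running dist_train.sh, run the script with bash instead of sh, e.g. bash ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM}. On many Linux systems, sh is a more limited shell (such as dash) that does not support the bash-specific parameter expansion used in the script.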
5. Unable to access or download the model from Hugging Face
If you cannot access or download the model from Hugging Face, use the download script instead, or refer to the official solution for this problem.
6. The segmentation loss is always 0 or becomes NaN
This is usually caused by unstable training due to a small batch size. Any one of the following solutions can help:
- Increase the batch size to 2 or 4 (this may exceed the available GPU memory).
- Use gradient accumulation by modifying optim_wrapper in the Config file:
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',  # change to 'bfloat16' for more stability (requires a GPU with bfloat16 support)
    optimizer=dict(
        type='AdamW',
        lr=base_lr,  # base_lr is defined elsewhere in the Config file
        weight_decay=0.05),
    accumulative_counts=4  # newly added: accumulate gradients over 4 iterations (any integer greater than 1)
)
- Disable the sine-cosine transformation in the Prompter during decoding by setting with_sincos=False in the Config file.
- Use a peft configuration with an input image size of 512 and increase the batch size.
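The first two options both target the effective batch size, i.e. the number of samples whose gradients contribute to each optimizer update. Under the usual data-parallel setup, that quantity is roughly the product sketched below. The function is hypothetical and simply multiplies batch_size_per_gpu, the number of GPUs passed to dist_train.sh, and accumulative_counts.

```python
def effective_batch_size(batch_size_per_gpu, num_gpus=1, accumulative_counts=1):
    """Samples contributing to one optimizer step under data parallelism
    with gradient accumulation (a simplified model)."""
    if min(batch_size_per_gpu, num_gpus, accumulative_counts) < 1:
        raise ValueError("all arguments must be >= 1")
    return batch_size_per_gpu * num_gpus * accumulative_counts


# A single RTX 4090 with batch size 2 (RSPrompter-anchor row of the table)
# combined with accumulative_counts=4 from the optim_wrapper example:
print(effective_batch_size(2, num_gpus=1, accumulative_counts=4))  # 8
```

Accumulation only approximates a larger batch: layers that compute statistics per forward pass, such as BatchNorm, still see the small per-GPU batch.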
Acknowledgement
This project is built on the MMDetection project. We thank the MMDetection developers.
Citation
If you use the code or performance benchmarks of this project in your research, please cite RSPrompter using the BibTeX entry below.
@article{chen2024rsprompter,
title={RSPrompter: Learning to prompt for remote sensing instance segmentation based on visual foundation model},
author={Chen, Keyan and Liu, Chenyang and Chen, Hao and Zhang, Haotian and Li, Wenyuan and Zou, Zhengxia and Shi, Zhenwei},
journal={IEEE Transactions on Geoscience and Remote Sensing},
year={2024},
publisher={IEEE}
}
License
This project is licensed under the Apache 2.0 license.
Contact
If you have any other questions ❓, feel free to contact us 👬