3DSAM-adapter: Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Implementation for the paper 3DSAM-adapter: Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation by Shizhan Gong, Yuan Zhong, Wenao Ma, Jinpeng Li, Zhao Wang, Jingyang Zhang, Pheng-Ann Heng, and Qi Dou.

Details

Although the Segment Anything Model (SAM) has achieved impressive results on general-purpose semantic segmentation, with strong generalization ability on everyday images, its performance on medical image segmentation is less precise and less stable, especially for tumor segmentation tasks that involve small objects with irregular shapes and low contrast. Notably, the original SAM architecture is designed for 2D natural images and therefore cannot effectively extract 3D spatial information from volumetric medical data. In this paper, we propose a novel adaptation method for transferring SAM from 2D to 3D for promptable medical image segmentation. Through a holistically designed scheme for architecture modification, we transfer SAM to support volumetric inputs while retaining the majority of its pre-trained parameters for reuse. Fine-tuning is conducted in a parameter-efficient manner: most of the pre-trained parameters remain frozen, and only a few lightweight spatial adapters are introduced and tuned. Despite the domain gap between natural and medical data and the disparity in spatial arrangement between 2D and 3D, the transformer trained on natural images can effectively capture the spatial patterns present in volumetric medical images with only lightweight adaptations. We conduct experiments on four open-source tumor segmentation datasets. With a single click prompt, our model outperforms domain-specific state-of-the-art medical image segmentation models on 3 out of 4 tasks, by 8.25%, 29.87%, and 10.11% for kidney tumor, pancreas tumor, and colon cancer segmentation, respectively, and achieves similar performance on liver tumor segmentation. We also compare our adaptation method with existing popular adapters and observe significant performance improvements on most datasets.
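The parameter-efficient idea described above, freezing most pre-trained weights and tuning only lightweight adapters, can be illustrated with the framework-free sketch below. The parameter names, sizes, and the keyword-matching rule are hypothetical placeholders rather than the repository's actual layers; in PyTorch, the same effect is usually achieved by setting `requires_grad = False` on every frozen parameter.

```python
from typing import Dict, Tuple


def split_trainable(
    param_counts: Dict[str, int],
    trainable_keywords: Tuple[str, ...] = ("adapter",),
) -> Tuple[Dict[str, bool], float]:
    """Mark a parameter as trainable only if its name contains one of the keywords.

    Returns the per-parameter trainable flags and the fraction of all
    parameter elements that remain trainable.
    """
    flags = {
        name: any(key in name for key in trainable_keywords)
        for name in param_counts
    }
    total = sum(param_counts.values())
    trainable = sum(n for name, n in param_counts.items() if flags[name])
    return flags, (trainable / total if total else 0.0)


# Hypothetical parameter names and element counts, for illustration only.
params = {
    "image_encoder.blocks.0.attn.qkv.weight": 1_769_472,
    "image_encoder.blocks.0.mlp.fc1.weight": 2_359_296,
    "image_encoder.blocks.0.adapter.down.weight": 24_576,
    "image_encoder.blocks.0.adapter.up.weight": 24_576,
}
flags, ratio = split_trainable(params)
print(flags)
print(f"trainable fraction: {ratio:.2%}")  # trainable fraction: 1.18%
```

Matching on names keeps the frozen/trainable decision in one place, which makes it easy to report how small the tuned subset is relative to the full model.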

Datasets

We use four open-source datasets (kidney tumor, pancreas tumor, liver tumor, and colon cancer) for training and evaluating our model.

Sample Results


Get Started

Main Requirements

  • python==3.9.16
  • cuda==11.3
  • torch==1.12.1
  • torchvision==0.13.1
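To confirm that an environment matches the versions listed above, a small check such as the following sketch can be run. It uses only the standard library (`importlib.metadata`); the helper name `check_versions` is hypothetical, and the CUDA version is not checked here (with PyTorch installed, `torch.version.cuda` reports it).

```python
import sys
from importlib import metadata
from typing import Dict, Optional

# Versions listed under "Main Requirements".
REQUIRED = {"torch": "1.12.1", "torchvision": "0.13.1"}


def check_versions(required: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed_version, or None if missing} for each mismatch."""
    mismatches: Dict[str, Optional[str]] = {}
    for pkg, wanted in required.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            mismatches[pkg] = None
            continue
        # CUDA wheels carry a local suffix such as "+cu113"; compare the base version.
        if installed.split("+")[0] != wanted:
            mismatches[pkg] = installed
    return mismatches


if sys.version_info[:3] != (3, 9, 16):
    print(f"Note: found Python {sys.version.split()[0]}, the README lists 3.9.16")
print(check_versions(REQUIRED) or "torch/torchvision versions match")
```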

Installation

We suggest using Anaconda to set up the environment on Linux. If you have already installed Anaconda, you can skip this step.

wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh && bash Anaconda3-2020.11-Linux-x86_64.sh

Then, create the environment and install the packages using the provided requirements.txt:

conda create -n med_sam python=3.9.16
conda activate med_sam
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install git+https://github.com/deepmind/surface-distance.git
pip install -r requirements.txt

Our implementation assumes a single-GPU setting (NVIDIA A40), but it can be easily adapted to use multiple GPUs. Running it requires about 35 GB of GPU memory.

3DSAM-adapter (Ours)

To use the code, first go to the folder 3DSAM-adapter:

cd 3DSAM-adapter

Type the command below to train the 3DSAM-adapter:

python train.py --data kits --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/" 

The pre-trained weights of SAM-B can be downloaded here and should be placed in the folder ckpt. Users with more powerful GPUs can also use SAM-L or SAM-H as the base model.

Type the command below to evaluate the 3DSAM-adapter:

python test.py --data kits --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/" --num_prompts 1

Use --num_prompts to set the number of points used as prompts; the default value is 1.
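One plausible way to simulate N click prompts is to sample N distinct foreground voxels uniformly at random from the ground-truth mask. The sketch below does this in plain Python for a nested-list 3D mask; the function `sample_click_prompts` is hypothetical, and the repository's actual prompt-sampling logic in test.py may differ.

```python
import random
from typing import List, Sequence, Tuple

Point = Tuple[int, int, int]


def sample_click_prompts(
    mask: Sequence[Sequence[Sequence[int]]], num_prompts: int = 1, seed: int = 0
) -> List[Point]:
    """Sample `num_prompts` distinct foreground voxels (d, h, w) from a binary 3D mask."""
    foreground = [
        (d, h, w)
        for d, plane in enumerate(mask)
        for h, row in enumerate(plane)
        for w, value in enumerate(row)
        if value
    ]
    if len(foreground) < num_prompts:
        raise ValueError("mask has fewer foreground voxels than requested prompts")
    rng = random.Random(seed)  # fixed seed so repeated evaluations see the same clicks
    return rng.sample(foreground, num_prompts)


# A tiny 2x3x3 volume containing a three-voxel "tumor".
mask = [
    [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
    [[0, 1, 1], [0, 0, 0], [0, 0, 0]],
]
print(sample_click_prompts(mask, num_prompts=1))
```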

Our pretrained checkpoint can be downloaded through OneDrive. For all four datasets, the crop size is 128.

Baselines

We provide our implementation of the baselines, including Swin-UNETR, UNETR, 3D UX-Net, nnFormer, UNETR++, and TransBTS.

To use the code, first go to the folder baselines:

cd baselines

Type the command below to train the baselines:

python train.py --data kits -m swin_unetr --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/"

Use --data to specify the dataset: one of kits, pancreas, lits, or colon.

Use -m to specify the method: one of swin_unetr, unetr, 3d_uxnet, nnformer, unetr++, or transbts.
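Training every baseline on every dataset means 4 × 6 = 24 runs. The sketch below builds the corresponding argument lists using only the flags documented above; the per-run snapshot sub-folder naming is a hypothetical convention, not one the repository prescribes.

```python
import itertools
import shlex
from typing import List

DATASETS = ["kits", "pancreas", "lits", "colon"]
METHODS = ["swin_unetr", "unetr", "3d_uxnet", "nnformer", "unetr++", "transbts"]


def baseline_commands(script: str, snapshot_root: str, data_prefix: str) -> List[List[str]]:
    """Build one argv list per (dataset, method) pair."""
    commands = []
    for data, method in itertools.product(DATASETS, METHODS):
        commands.append([
            "python", script,
            "--data", data,
            "-m", method,
            # Hypothetical layout: one snapshot sub-folder per dataset/method pair.
            "--snapshot_path", f"{snapshot_root}/{data}_{method}/",
            "--data_prefix", data_prefix,
        ])
    return commands


for argv in baseline_commands("train.py", "path/to/snapshot", "path/to/data folder/")[:2]:
    print(shlex.join(argv))  # shlex.join quotes arguments that contain spaces
```

Each list can be passed directly to `subprocess.run(argv, check=True)` from inside the baselines folder; passing a list avoids shell-quoting problems with paths such as "path/to/data folder/".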

For training Swin-UNETR, download the checkpoint and put it under the folder ckpt.

We use different hyper-parameters for each dataset; for details, please refer to datasets.py. The crop size is set to (64, 160, 160) for all datasets.
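To make the fixed crop size concrete, here is a minimal sketch that computes a crop window of a given size centered on a chosen voxel, shifts it to stay inside the volume, and reports any padding needed when the volume is smaller than the crop. It assumes the (64, 160, 160) tuple is ordered (depth, height, width), and the volume shapes are hypothetical; the actual cropping pipeline is defined in datasets.py and may differ.

```python
from typing import List, Sequence, Tuple


def crop_window(
    shape: Sequence[int], center: Sequence[int], crop: Sequence[int]
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Return per-axis (start, end) slice bounds and (before, after) padding.

    The window of size `crop` is centered on `center` and shifted, where possible,
    so that it stays inside a volume of size `shape`; any remaining shortfall
    (volume smaller than the crop along an axis) is reported as padding.
    """
    bounds, padding = [], []
    for size, c, k in zip(shape, center, crop):
        start = c - k // 2
        # Shift the window back inside the volume along this axis.
        start = max(0, min(start, size - k))
        end = min(start + k, size)
        missing = k - (end - start)
        padding.append((missing // 2, missing - missing // 2))
        bounds.append((start, end))
    return bounds, padding


# Hypothetical CT volume of shape (D, H, W) = (80, 512, 512).
bounds, padding = crop_window((80, 512, 512), (10, 256, 500), (64, 160, 160))
print(bounds, padding)  # [(0, 64), (176, 336), (352, 512)] [(0, 0), (0, 0), (0, 0)]
```

Shifting the window instead of padding whenever the volume is large enough keeps every cropped voxel real image data; padding is only needed along axes where the volume is shorter than the crop.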

Type the command below to evaluate the baselines:

python test.py --data kits -m swin_unetr --snapshot_path "path/to/snapshot/" --data_prefix "path/to/data folder/"

Feedback and Contact

For any questions, please contact szgong22@cse.cuhk.edu.hk

Acknowledgement

Our code is based on Segment-Anything, 3D UX-Net, and Swin UNETR.

Citation

If you find this code useful, please cite our paper in your research.

@article{Gong20233DSAMadapterHA,
  title={3DSAM-adapter: Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation},
  author={Shizhan Gong and Yuan Zhong and Wenao Ma and Jinpeng Li and Zhao Wang and Jingyang Zhang and Pheng-Ann Heng and Qi Dou},
  journal={arXiv preprint arXiv:2306.13465},
  year={2023}
}