Official implementation of SAM-Med2D

SAM-Med2D [Paper]

🌀️ Highlights

  • πŸ† Collected and curated the largest medical image segmentation dataset (4.6M images and 19.7M masks) to date for training models.
  • πŸ† The most comprehensive fine-tuning based on Segment Anything Model (SAM).
  • πŸ† Comprehensive evaluation of SAM-Med2D on large-scale datasets.

🔥 Updates

  • (2023.12.05) The dataset is now available for download on the Hugging Face platform
  • (2023.11.23) We have released the SA-Med2D-20M dataset
  • (2023.11.21) We have released an article introducing the SA-Med2D-20M dataset
  • (2023.10.24) We have released SAM-Med3D, which focuses on segmentation of 3D medical images
  • (2023.09.14) Train code release
  • (2023.09.02) Test code release
  • (2023.08.31) Pre-trained model release
  • (2023.08.31) Paper release
  • (2023.08.26) Online Demo release

👉 Dataset

SAM-Med2D is trained and tested on a dataset that includes 4.6M images and 19.7M masks. This dataset covers 10 medical data modalities, 4 anatomical structures + lesions, and 31 major human organs. To our knowledge, this is currently the largest and most diverse medical image segmentation dataset in terms of quantity and coverage of categories.

👉 Framework

The pipeline of SAM-Med2D. We freeze the image encoder and incorporate learnable adapter layers in each Transformer block to acquire domain-specific knowledge in the medical field. We fine-tune the prompt encoder using point, Bbox, and mask information, while updating the parameters of the mask decoder through interactive training.

[Figure: the SAM-Med2D pipeline]
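The freezing scheme described above can be read as a rule over parameter names. The following is a minimal sketch of that rule, not the repository's code: the parameter names, and the convention that adapter weights contain the substring `Adapter`, are assumptions made for illustration.

```python
def is_trainable(param_name: str, encoder_adapter: bool = True) -> bool:
    """Decide whether a parameter is updated during fine-tuning.

    Assumed (hypothetical) naming: image-encoder parameters start with
    "image_encoder." and adapter weights contain "Adapter".
    """
    if param_name.startswith("image_encoder."):
        # The image encoder is frozen; only its adapter layers learn,
        # and only when encoder_adapter is enabled.
        return encoder_adapter and "Adapter" in param_name
    # Prompt encoder and mask decoder parameters are fine-tuned.
    return True


# Hypothetical parameter names, for illustration only.
names = [
    "image_encoder.blocks.0.attn.qkv.weight",
    "image_encoder.blocks.0.Adapter.fc1.weight",
    "prompt_encoder.point_embeddings.0.weight",
    "mask_decoder.output_upscaling.0.weight",
]
trainable = [n for n in names if is_trainable(n)]
```

In a PyTorch model, a rule like this would typically be applied by setting `param.requires_grad = is_trainable(name)` for each pair returned by `model.named_parameters()`.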

👉 Results

Quantitative comparison of different methods on the test set:

| Model | Resolution | Bbox (%) | 1 pt (%) | 3 pts (%) | 5 pts (%) | FPS | Checkpoint |
|---|---|---|---|---|---|---|---|
| SAM | $256\times256$ | 61.63 | 18.94 | 28.28 | 37.47 | 51 | Official |
| SAM | $1024\times1024$ | 74.49 | 36.88 | 42.00 | 47.57 | 8 | Official |
| FT-SAM | $256\times256$ | 73.56 | 60.11 | 70.95 | 75.51 | 51 | FT-SAM |
| SAM-Med2D | $256\times256$ | 79.30 | 70.01 | 76.35 | 78.68 | 35 | SAM-Med2D |

Baidu Netdisk link: https://pan.baidu.com/s/1HWo_s8O7r4iQI6irMYU8vQ?pwd=dk5x Extraction code: dk5x
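To make the comparison concrete, the short sketch below computes the gain of SAM-Med2D over the other $256\times256$ models, using only values copied from the table above. The dictionary layout and the function name are illustrative.

```python
# Scores (%) at 256x256 resolution, copied from the table above.
PROMPTS = ("Bbox", "1 pt", "3 pts", "5 pts")
scores = {
    "SAM": (61.63, 18.94, 28.28, 37.47),
    "FT-SAM": (73.56, 60.11, 70.95, 75.51),
    "SAM-Med2D": (79.30, 70.01, 76.35, 78.68),
}


def gain_over(baseline: str) -> dict:
    """Percentage-point gain of SAM-Med2D over `baseline`, per prompt mode."""
    return {
        prompt: round(ours - base, 2)
        for prompt, ours, base in zip(PROMPTS, scores["SAM-Med2D"], scores[baseline])
    }


for name in ("SAM", "FT-SAM"):
    print(name, gain_over(name))
```

The gap to SAM is widest in the single-point setting. The FPS column shows the cost: 35 FPS for SAM-Med2D against 51 FPS for the other two models at the same resolution.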

Generalization validation on 9 MICCAI2023 datasets, where "*" denotes that the adapter layers of SAM-Med2D are dropped in the test phase. The first three score columns use the Bbox prompt (%), the last three use the 1 point prompt (%):

| Dataset | SAM (Bbox) | SAM-Med2D* (Bbox) | SAM-Med2D (Bbox) | SAM (1 pt) | SAM-Med2D* (1 pt) | SAM-Med2D (1 pt) |
|---|---|---|---|---|---|---|
| CrossMoDA23 | 78.12 | 86.26 | 88.42 | 33.84 | 65.85 | 85.26 |
| KiTS23 | 81.52 | 86.14 | 89.89 | 31.36 | 56.67 | 83.71 |
| FLARE23 | 73.20 | 77.18 | 85.09 | 19.87 | 32.01 | 77.17 |
| ATLAS2023 | 76.98 | 79.09 | 82.59 | 29.07 | 45.25 | 64.76 |
| SEG2023 | 64.82 | 81.85 | 85.09 | 21.15 | 34.71 | 72.08 |
| LNQ2023 | 53.02 | 57.37 | 58.01 | 7.05 | 7.21 | 37.64 |
| CAS2023 | 61.53 | 78.20 | 81.10 | 22.75 | 46.85 | 78.46 |
| TDSC-ABUS2023 | 64.31 | 69.00 | 66.14 | 8.24 | 18.98 | 43.55 |
| ToothFairy2023 | 43.40 | 39.13 | 41.23 | 5.47 | 5.27 | 12.93 |
| Weighted sum | 73.49 | 77.67 | 84.88 | 20.88 | 34.30 | 76.63 |
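The effect of the adapter layers can be read off the table by subtracting the SAM-Med2D* scores from the SAM-Med2D scores. The sketch below does this for the Bbox prompt, using values copied from the table. It also computes the plain mean over the nine datasets, which differs from the reported "Weighted sum" row (the weights themselves are not given here).

```python
# Bbox-prompt scores (%) copied from the table above:
# dataset -> (SAM-Med2D* without adapter layers, SAM-Med2D)
bbox = {
    "CrossMoDA23": (86.26, 88.42),
    "KiTS23": (86.14, 89.89),
    "FLARE23": (77.18, 85.09),
    "ATLAS2023": (79.09, 82.59),
    "SEG2023": (81.85, 85.09),
    "LNQ2023": (57.37, 58.01),
    "CAS2023": (78.20, 81.10),
    "TDSC-ABUS2023": (69.00, 66.14),
    "ToothFairy2023": (39.13, 41.23),
}

# Percentage-point change from keeping the adapter layers at test time.
adapter_gain = {name: round(full - dropped, 2) for name, (dropped, full) in bbox.items()}

# Datasets where dropping the adapter layers scored higher.
worse_with_adapter = [name for name, gain in adapter_gain.items() if gain < 0]

# Unweighted mean of the SAM-Med2D scores, for comparison with the
# reported weighted value of 84.88.
plain_mean = round(sum(full for _, full in bbox.values()) / len(bbox), 2)
```

With the Bbox prompt, keeping the adapter layers helps on eight of the nine datasets; TDSC-ABUS2023 is the exception.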

Errata for the paper

The test data presented in Table 4 of our original paper contained anomalies. We have updated the data for this project and corrected the values, and Table 4 will be updated in the next version of the paper. We apologize for any inconvenience this may have caused.

👉 Visualization

image

👉 Train

Prepare your own dataset, using the samples in SAM-Med2D/data_demo as a template and replacing them with data for your specific scenario. You need to generate the image2label_train.json file before running train.py.
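A minimal sketch of generating such an index file follows. It assumes, as the file name suggests, that image2label_train.json maps each image path to a list of its mask paths; the `images`/`masks` directory names and the `<image name>_*.png` mask naming are hypothetical. Compare the output with the sample file in data_demo before relying on it.

```python
import json
from pathlib import Path


def build_image2label(data_path: Path) -> dict:
    """Map each image path to the sorted list of its mask paths.

    Assumed (hypothetical) layout: <data_path>/images/<name>.png and
    <data_path>/masks/<name>_<anything>.png.
    """
    mapping = {}
    for image in sorted((data_path / "images").glob("*.png")):
        masks = sorted((data_path / "masks").glob(f"{image.stem}_*.png"))
        if masks:  # skip images that have no masks
            mapping[image.as_posix()] = [m.as_posix() for m in masks]
    return mapping


def write_index(data_path: Path) -> Path:
    """Write the mapping to <data_path>/image2label_train.json."""
    out = data_path / "image2label_train.json"
    out.write_text(json.dumps(build_image2label(data_path), indent=2))
    return out
```

Calling `write_index(Path("data_demo"))` would write the file next to the data. Note that the prefix match is loose: if one image name is a prefix of another (for example `a` and `a_b`), stricter matching is needed.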

If you want to use mixed-precision training, please install Apex. If you don't want to install Apex, comment out the line `from apex import amp` and set use_amp to False.
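As an alternative to commenting out the import by hand, the Apex dependency could be guarded at runtime. This is a sketch of that idea rather than the code in train.py; the helper name `resolve_amp` is hypothetical.

```python
import importlib
import importlib.util


def resolve_amp(use_amp: bool):
    """Return (amp_module, use_amp), disabling AMP when Apex is absent."""
    if not use_amp:
        return None, False
    if importlib.util.find_spec("apex") is None:
        # Apex is not installed: fall back to full-precision training.
        return None, False
    # Roughly equivalent to `from apex import amp`.
    amp = importlib.import_module("apex.amp")
    return amp, True


amp, use_amp = resolve_amp(use_amp=True)
```

With this guard, `use_amp` ends up True only when mixed precision was requested and Apex could be found.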

cd ./SAM-Med2D
python train.py
  • work_dir: Working directory for the training process. Default value is workdir.
  • image_size: Default value is 256.
  • mask_num: Number of masks corresponding to one image. Default value is 5.
  • data_path: Dataset directory, for example data_demo.
  • resume: Pretrained weight file to resume from; if set, sam_checkpoint is ignored.
  • sam_checkpoint: SAM checkpoint to load.
  • iter_point: Number of iterative runs of the mask decoder.
  • multimask: Whether to output multiple masks. Default value is True.
  • encoder_adapter: Whether to fine-tune the adapter layers; set to False to fine-tune only the decoder.
  • use_amp: Whether to use mixed-precision training.
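For orientation, here is how the options above might be declared with argparse. This is an illustrative sketch, not a copy of train.py: only the defaults stated in the list are filled in, while the remaining defaults (marked in the comments) and the `str2bool` helper are assumptions.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (hypothetical helper)."""
    return value.lower() in ("1", "true", "yes")


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Training options (sketch)")
    p.add_argument("--work_dir", default="workdir")        # stated default
    p.add_argument("--image_size", type=int, default=256)  # stated default
    p.add_argument("--mask_num", type=int, default=5)      # stated default
    p.add_argument("--data_path", default="data_demo")     # example value from the list
    p.add_argument("--resume", default=None)               # assumed default
    p.add_argument("--sam_checkpoint", default=None)       # assumed default
    p.add_argument("--iter_point", type=int, default=1)    # assumed default
    p.add_argument("--multimask", type=str2bool, default=True)        # stated default
    p.add_argument("--encoder_adapter", type=str2bool, default=True)  # assumed default
    p.add_argument("--use_amp", type=str2bool, default=False)         # assumed default
    return p


args = build_parser().parse_args(["--mask_num", "3", "--use_amp", "True"])
```

Plain `type=bool` is avoided here because argparse would turn any non-empty string, including "False", into True.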

👉 Test

Prepare your own dataset, using the samples in SAM-Med2D/data_demo as a template and replacing them with data for your specific scenario. You need to generate the label2image_test.json file before running test.py.

cd ./SAM-Med2D
python test.py
  • work_dir: Working directory for the testing process. Default value is workdir.
  • batch_size: Default value is 1.
  • image_size: Default value is 256.
  • boxes_prompt: Whether to use Bbox prompts to obtain segmentation results.
  • point_num: Number of points. Default value is 1.
  • iter_point: Number of iterations for point prompts.
  • sam_checkpoint: SAM or SAM-Med2D checkpoint to load.
  • encoder_adapter: Set to True when using SAM-Med2D's pretrained weights.
  • save_pred: Whether to save the prediction results.
  • prompt_path: Path to a fixed prompt file. If there is none, the value is None and the prompts are generated automatically in the latest prediction.
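To illustrate what the two prompt types look like, the sketch below derives a bounding box and a few point prompts from a binary ground-truth mask. It is a simplified illustration in plain Python, not the sampling strategy used by test.py; the function names and the (x, y) coordinate convention are assumptions.

```python
import random


def mask_to_bbox(mask):
    """Tight box (x_min, y_min, x_max, y_max) around the nonzero pixels."""
    points = [(x, y) for y, row in enumerate(mask) for x, v in enumerate(row) if v]
    if not points:
        raise ValueError("mask has no foreground pixels")
    xs, ys = zip(*points)
    return min(xs), min(ys), max(xs), max(ys)


def sample_points(mask, point_num=1, seed=0):
    """Pick `point_num` distinct foreground pixels as (x, y) point prompts."""
    foreground = [(x, y) for y, row in enumerate(mask) for x, v in enumerate(row) if v]
    rng = random.Random(seed)
    return rng.sample(foreground, min(point_num, len(foreground)))


mask = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
box = mask_to_bbox(mask)  # (1, 1, 3, 2)
points = sample_points(mask, point_num=3)
```

A Bbox prompt gives the model the extent of the target, whereas a single point only marks one location inside it.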

👉 Deploy

Export to ONNX

  • export encoder model
python3 scripts/export_onnx_encoder_model.py --sam_checkpoint /path/to/sam-med2d_b.pth --output /path/to/sam-med2d_b.encoder.onnx --model-type vit_b --image_size 256 --encoder_adapter True
  • export decoder model
python3 scripts/export_onnx_model.py --checkpoint /path/to/sam-med2d_b.pth --output /path/to/sam-med2d_b.decoder.onnx --model-type vit_b --return-single-mask
  • inference with onnxruntime
# cd examples/SAM-Med2D-onnxruntime
python3 main.py --encoder_model /path/to/sam-med2d_b.encoder.onnx --decoder_model /path/to/sam-med2d_b.decoder.onnx
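When several checkpoints need exporting, the two export commands can be assembled programmatically so that the checkpoint path is written only once. The helper below only builds argument lists from the flags shown above; the function name and the output naming scheme (taken from the example paths) are illustrative.

```python
import shlex
from pathlib import PurePosixPath


def export_commands(checkpoint: str, image_size: int = 256, model_type: str = "vit_b"):
    """Build the encoder and decoder export commands for one checkpoint."""
    ckpt = PurePosixPath(checkpoint)
    encoder_out = ckpt.with_suffix(".encoder.onnx")
    decoder_out = ckpt.with_suffix(".decoder.onnx")
    encoder_cmd = [
        "python3", "scripts/export_onnx_encoder_model.py",
        "--sam_checkpoint", str(ckpt),
        "--output", str(encoder_out),
        "--model-type", model_type,
        "--image_size", str(image_size),
        "--encoder_adapter", "True",
    ]
    decoder_cmd = [
        "python3", "scripts/export_onnx_model.py",
        "--checkpoint", str(ckpt),
        "--output", str(decoder_out),
        "--model-type", model_type,
        "--return-single-mask",
    ]
    return encoder_cmd, decoder_cmd


for cmd in export_commands("/path/to/sam-med2d_b.pth"):
    # Print the command; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
    print(shlex.join(cmd))
```

Building the commands as lists rather than strings avoids quoting problems when paths contain spaces.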

🚀 Try SAM-Med2D

🗓️ Ongoing

  • Dataset release
  • Dataset article release
  • Train code release
  • Test code release
  • Pre-trained model release
  • Paper release
  • Online Demo release

🎫 License

This project is released under the Apache 2.0 license.

💬 Discussion Group

If you have any questions about SAM-Med2D, please add this WeChat ID to join the WeChat discussion group:

image

🀝 Acknowledgement

  • We thank all medical workers and dataset owners for making public datasets available to the community.
  • Thanks to the open-source of the following projects: Segment Anything  

👋 Hiring & Global Collaboration

  • Hiring: We are hiring researchers, engineers, and interns in General Vision Group, Shanghai AI Lab. If you are interested in Medical Foundation Models and General Medical AI, including designing benchmark datasets, general models, evaluation systems, and efficient tools, please contact us.
  • Global Collaboration: We're on a mission to redefine medical research, aiming for a more universally adaptable model. Our passionate team is delving into foundational healthcare models, promoting the development of the medical community. Collaborate with us to increase competitiveness, reduce risk, and expand markets.
  • Contact: Junjun He (hejunjun@pjlab.org.cn), Jin Ye (yejin@pjlab.org.cn), and Tianbin Li (litianbin@pjlab.org.cn).

Reference

@misc{cheng2023sammed2d,
      title={SAM-Med2D}, 
      author={Junlong Cheng and Jin Ye and Zhongying Deng and Jianpin Chen and Tianbin Li and Haoyu Wang and Yanzhou Su and
              Ziyan Huang and Jilong Chen and Lei Jiang and Hui Sun and Junjun He and Shaoting Zhang and Min Zhu and Yu Qiao},
      year={2023},
      eprint={2308.16184},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@misc{ye2023samed2d20m,
      title={SA-Med2D-20M Dataset: Segment Anything in 2D Medical Imaging with 20 Million masks}, 
      author={Jin Ye and Junlong Cheng and Jianpin Chen and Zhongying Deng and Tianbin Li and Haoyu Wang and Yanzhou Su and Ziyan Huang and Jilong Chen and Lei Jiang and Hui Sun and Min Zhu and Shaoting Zhang and Junjun He and Yu Qiao},
      year={2023},
      eprint={2311.11969},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}