
[ECCV2024] AnatoMask: Enhancing Medical Image Segmentation with Reconstruction-guided Self-masking. Official PyTorch implementation of AnatoMask.


AnatoMask: Enhancing Medical Image Segmentation with Reconstruction-guided Self-masking


Linux-version for AnatoMask

Please clone our repo and install nnUNetv2 from our source by running the following in the repo root:

pip install -e .

Please also make sure your environment variables (nnUNet_raw, nnUNet_preprocessed, nnUNet_results) are set correctly, following nnUNet's instructions.

Other packages that we use: CUDA 12.1, torch==2.0.1, SimpleITK==2.3.1
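A quick sanity check of this setup can be sketched in Python. It only looks at the three environment variables nnUNetv2 reads and at the installed torch and SimpleITK versions listed above. The helper names are hypothetical and not part of this repo, and the CUDA version is not checked here (torch.version.cuda reports it once torch is installed).

```python
import os
from importlib.metadata import PackageNotFoundError, version

# Folders that nnUNetv2 reads from the environment.
NNUNET_ENV_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")

# Versions listed above (CUDA is not checked by this sketch).
EXPECTED = {"torch": "2.0.1", "SimpleITK": "2.3.1"}


def missing_env_vars(env=os.environ):
    """Return the nnUNet environment variables that are unset or empty."""
    return [name for name in NNUNET_ENV_VARS if not env.get(name)]


def version_mismatches(expected=EXPECTED):
    """Map package -> (installed, expected) for packages that are missing or differ."""
    out = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None
        # torch wheels may carry a local tag such as "+cu121", so compare without it.
        if have is None or have.split("+")[0] != want:
            out[pkg] = (have, want)
    return out


if __name__ == "__main__":
    print("missing env vars:", missing_env_vars() or "none")
    print("version mismatches:", version_mismatches() or "none")
```

Mismatches are reported rather than raised, since other versions may work fine.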

What is AnatoMask?

Our hypothesis is pretty simple: masked image modeling + ConvNet backbone = success for medical image segmentation.

Given nnUNet's state-of-the-art performance, we offer the option to run self-supervised pretraining on top of nnUNet's whole pipeline.

Currently, we offer two options: (1) SparK, which is the CNN equivalent of masked autoencoders, and (2) AnatoMask, which refines SparK by bootstrapping difficult regions to form harder pretraining masks. After pretraining on a dataset, you can transfer the weights to downstream segmentation tasks.
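To make the SparK baseline concrete: SparK draws its random mask at the resolution of the encoder's coarsest feature map and then upsamples it to the input. Below is a minimal NumPy sketch of that idea for a 3D patch. The grid shape, the per-axis ratio, the mask ratio, and the function names are illustrative assumptions, not this repo's code.

```python
import numpy as np


def random_patch_mask(grid_shape, mask_ratio, rng):
    """Randomly mask a fraction of the cells in a coarse 3D grid (True = masked)."""
    n = int(np.prod(grid_shape))
    n_mask = int(round(mask_ratio * n))
    flat = np.zeros(n, dtype=bool)
    flat[rng.choice(n, size=n_mask, replace=False)] = True
    return flat.reshape(grid_shape)


def upsample_mask(mask, ratio):
    """Repeat each grid cell ratio[axis] times along every axis to reach voxel resolution."""
    for axis, r in enumerate(ratio):
        mask = np.repeat(mask, r, axis=axis)
    return mask


rng = np.random.default_rng(0)
grid_mask = random_patch_mask((4, 4, 4), mask_ratio=0.6, rng=rng)
voxel_mask = upsample_mask(grid_mask, (16, 32, 32))  # e.g. a 64 x 128 x 128 input patch
print(int(grid_mask.sum()), voxel_mask.shape)
```

Every voxel inside a masked grid cell is masked, so the model has to reconstruct whole blocks rather than isolated voxels.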

Check out this comparison.

Figure: Comparison with random masking

Currently, all our backbones are CNNs, which we find work well for segmentation :)

Pretraining using AnatoMask

Step 1: Prepare your segmentation model's encoder. An example is given in STUNet_head.py.

For more info on building your own CNN encoder, refer to SparK's guidelines.
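SparK's guideline asks a custom encoder to report its total downsampling (get_downsample_ratio) so that masks can be drawn on the coarsest feature map. For an anisotropic 3D encoder, a per-axis version of that number can be computed from the per-stage strides, as in the hypothetical helper below. The example strides are illustrative, and where you read them from (for instance your plans file) is an assumption that depends on your nnUNet version.

```python
from math import prod
from typing import Sequence, Tuple


def downsample_ratio(strides: Sequence[Sequence[int]]) -> Tuple[int, ...]:
    """Per-axis product of the per-stage strides, i.e. input size / coarsest feature map size."""
    return tuple(prod(axis) for axis in zip(*strides))


def mask_grid_shape(patch_size: Sequence[int], strides: Sequence[Sequence[int]]) -> Tuple[int, ...]:
    """Shape of the coarsest feature map; raises if the patch is not divisible by the ratio."""
    ratio = downsample_ratio(strides)
    for p, r in zip(patch_size, ratio):
        if p % r:
            raise ValueError(f"patch size {p} is not divisible by downsample ratio {r}")
    return tuple(p // r for p, r in zip(patch_size, ratio))


# Illustrative strides for a six-stage encoder whose first stage does not downsample.
strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]]
print(downsample_ratio(strides), mask_grid_shape((64, 128, 128), strides))
```

Checking divisibility up front avoids a silent mismatch between the mask grid and the encoder's feature map.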

Step 2: Go to ssl_pretrain.

A few more things to do:

  • Set output_folder = 'XXX'. This folder stores your saved model weights.
  • Set preprocessed_dataset_folder = 'XXX/nnUNet_preprocessed/Dataset009_Spleen/nnUNetPlans_3d_fullres'. This is your preprocessed nnUNet dataset, so be sure to preprocess your dataset first following nnUNet's tutorial!
  • Find your nnUNet splits file (or create your own split if you are so inclined): splits_file = 'XXX/nnUNet_preprocessed/Dataset009_Spleen/splits_final.json'. You can get this file by running nnUNet training once on your dataset.
  • Find your dataset json file: dataset_json = load_json('XXX/Dataset009_Spleen/dataset.json')
  • Find your plans json file: plans = load_json('XXX/nnUNet_preprocessed/Dataset009_Spleen/nnUNetPlans.json')
  • Run: python pretrain_AnatoMask.py
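If you prefer to derive these paths instead of typing them by hand, or want to create your own split, here is a small sketch. It assumes the standard nnUNetv2 folder layout shown above (including a copy of dataset.json in the preprocessed dataset folder, which nnUNetv2 normally writes during preprocessing) and the splits_final.json format: a list with one {"train": [...], "val": [...]} dict per fold. The function names and case IDs are hypothetical, and this random split will not reproduce nnUNet's own folds.

```python
import json
import random
from pathlib import Path
from typing import Dict, List


def nnunet_paths(preprocessed_root: str, dataset_name: str,
                 plans_name: str = "nnUNetPlans", configuration: str = "3d_fullres") -> Dict[str, Path]:
    """Build the paths the pretraining setup asks for from an nnUNet_preprocessed root."""
    root = Path(preprocessed_root) / dataset_name
    return {
        "preprocessed_dataset_folder": root / f"{plans_name}_{configuration}",
        "splits_file": root / "splits_final.json",
        "dataset_json": root / "dataset.json",
        "plans": root / f"{plans_name}.json",
    }


def make_kfold_splits(case_ids: List[str], n_folds: int = 5, seed: int = 0) -> List[Dict[str, List[str]]]:
    """Random k-fold split in the splits_final.json layout: one {"train", "val"} dict per fold."""
    ids = sorted(case_ids)
    random.Random(seed).shuffle(ids)
    folds = [ids[i::n_folds] for i in range(n_folds)]
    return [
        {"train": sorted(c for j, fold in enumerate(folds) if j != k for c in fold),
         "val": sorted(folds[k])}
        for k in range(n_folds)
    ]


def write_splits(splits: List[Dict[str, List[str]]], path) -> None:
    """Save the splits as JSON, e.g. to the splits_file path."""
    Path(path).write_text(json.dumps(splits, indent=4))


paths = nnunet_paths("/data/nnUNet_preprocessed", "Dataset009_Spleen")
splits = make_kfold_splits([f"case_{i}" for i in range(10)])
print(paths["preprocessed_dataset_folder"], len(splits), splits[0]["val"])
```

Reading the JSON files with json.loads(path.read_text()) gives the same Python objects that load_json returns.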

Note: You can use SparK by following the same steps and running pretrain.py instead.

Finetuning

Define your function to load pretrained weights (an example is provided).

Import your function and use it to replace nnUNet's load_pretrained_weights.
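Here is one possible shape for such a function (a sketch, not the repo's example). It keeps only the pretrained tensors whose names and shapes match the downstream network, so a decoder or segmentation head that differs from pretraining keeps its random initialization. The checkpoint keys tried ("network_weights", "state_dict", "model") and the strip_prefix/skip_prefixes options are assumptions, so check how your pretraining script actually saves its weights.

```python
from typing import Dict, Iterable, List, Mapping, Tuple


def filter_pretrained(pretrained: Mapping, model_state: Mapping, strip_prefix: str = "",
                      skip_prefixes: Iterable[str] = ()) -> Tuple[Dict, List[str]]:
    """Keep tensors whose (renamed) key exists in the model with the same shape."""
    skip_prefixes = tuple(skip_prefixes)
    kept, dropped = {}, []
    for key, value in pretrained.items():
        name = key[len(strip_prefix):] if strip_prefix and key.startswith(strip_prefix) else key
        matches = (not name.startswith(skip_prefixes)
                   and name in model_state
                   and tuple(model_state[name].shape) == tuple(value.shape))
        if matches:
            kept[name] = value
        else:
            dropped.append(key)
    return kept, dropped


def load_pretrained_encoder(model, ckpt_path: str, strip_prefix: str = "",
                            skip_prefixes: Iterable[str] = ()):
    """Load matching pretrained weights into `model`; everything else stays as initialized."""
    import torch  # imported here so filter_pretrained can be used without torch

    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Hypothetical: unwrap a few common checkpoint layouts.
    for key in ("network_weights", "state_dict", "model"):
        if isinstance(ckpt, dict) and key in ckpt:
            ckpt = ckpt[key]
            break
    kept, dropped = filter_pretrained(ckpt, model.state_dict(), strip_prefix, skip_prefixes)
    model.load_state_dict(kept, strict=False)
    return kept, dropped
```

Returning the dropped keys makes it easy to confirm that only the layers you expect were left out.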

Finally, run your nnUNet training command as usual, but add -pretrained_weights PATH_TO_YOUR_WEIGHTS.
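For example, nnUNetv2's nnUNetv2_train entry point takes the dataset, configuration, and fold as positional arguments and the trainer class via -tr. The call can be assembled in Python as shown. The dataset ID, fold, and weights path are placeholders, and the trainer name should match the class name exactly as it is defined in your copy of the repo.

```python
import shlex
from typing import List, Optional


def nnunet_train_cmd(dataset: str, configuration: str, fold: int,
                     pretrained_weights: str, trainer: Optional[str] = None) -> List[str]:
    """Assemble an nnUNetv2_train call that initializes from pretrained weights."""
    cmd = ["nnUNetv2_train", dataset, configuration, str(fold)]
    if trainer:
        cmd += ["-tr", trainer]
    cmd += ["-pretrained_weights", pretrained_weights]
    return cmd


cmd = nnunet_train_cmd("9", "3d_fullres", 0, "/path/to/pretrained.pth", trainer="STUNetTrainer")
print(shlex.join(cmd))
```

Passing the list to subprocess.run(cmd, check=True) launches training without any shell quoting issues.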

Our workflow currently supports STUNetTrainer.

If you want to use your own model, write your own trainer class following STUNetTrainer's example.

What exactly does AnatoMask do?

We propose a reconstruction-guided masking strategy, so that the model learns anatomically significant regions through reconstruction losses. This is done with self-distillation: a teacher network first identifies important regions to mask and generates a more difficult mask for the student to solve.
To prevent the network from converging to a suboptimal solution early in training, we use an easy-to-hard masking dynamics function that controls the difficulty of the masked image modeling (MIM) objective.

Figure: Overview
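The core of this idea can be sketched in NumPy under several assumptions. The teacher's difficulty signal is the per-patch mean squared reconstruction error, a hard_frac share of the masked patches are the highest-error ones with the rest drawn at random, and hard_frac grows linearly over training. The function names, the linear schedule, and max_fraction are hypothetical stand-ins for the paper's masking dynamics function, and training the teacher itself is not shown.

```python
import numpy as np


def hard_fraction(epoch: int, total_epochs: int, max_fraction: float = 0.5) -> float:
    """Hypothetical linear easy-to-hard schedule: 0 at the first epoch, max_fraction at the last."""
    if total_epochs <= 1:
        return max_fraction
    return max_fraction * min(epoch / (total_epochs - 1), 1.0)


def per_patch_mse(recon, target, patch):
    """Mean squared error per non-overlapping 3D patch (volume dims must be divisible by patch)."""
    d, h, w = target.shape
    pd, ph, pw = patch
    err = (np.asarray(recon, dtype=float) - np.asarray(target, dtype=float)) ** 2
    err = err.reshape(d // pd, pd, h // ph, ph, w // pw, pw)
    return err.mean(axis=(1, 3, 5))


def reconstruction_guided_mask(patch_losses, mask_ratio, hard_frac, rng):
    """Flat boolean mask (True = masked): hardest patches first, the remainder random."""
    losses = np.asarray(patch_losses, dtype=float).ravel()
    n = losses.size
    n_mask = int(round(mask_ratio * n))
    n_hard = int(round(hard_frac * n_mask))
    order = np.argsort(-losses, kind="stable")  # highest teacher loss first
    rand = rng.choice(order[n_hard:], size=n_mask - n_hard, replace=False)
    mask = np.zeros(n, dtype=bool)
    mask[order[:n_hard]] = True
    mask[rand] = True
    return mask


rng = np.random.default_rng(0)
target = np.zeros((4, 4, 4))
recon = target.copy()
recon[:2, :2, :2] = 1.0  # the teacher reconstructs one 2x2x2 patch badly
losses = per_patch_mse(recon, target, (2, 2, 2))
mask = reconstruction_guided_mask(losses, mask_ratio=0.25,
                                  hard_frac=hard_fraction(99, 100), rng=rng)
print(mask.reshape(losses.shape))
```

With hard_frac = 0 this reduces to purely random masking, which is why the schedule can start easy and become harder as training progresses.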

Model Zoo

TBC

TO DO

  • Release pretrained weights and finetuned weights.

Reference

Please cite the following when using AnatoMask:

Li, Y., Luan, T., Wu, Y., Pan, S., Chen, Y., & Yang, X. (2024). AnatoMask: Enhancing Medical Image Segmentation with Reconstruction-guided Self-masking. arXiv preprint arXiv:2407.06468.