PA-Seg: Learning from Point Annotations for 3D Medical Image Segmentation using Contextual Regularization and Cross Knowledge Distillation

Primary language: Python. License: MIT.

PA-Seg

Code for the paper: PA-Seg: Learning from Point Annotations for 3D Medical Image Segmentation using Contextual Regularization and Cross Knowledge Distillation (TMI 2023)

@article{zhai2023pa,
        title={PA-Seg: Learning from Point Annotations for 3D Medical Image Segmentation using Contextual Regularization and Cross Knowledge Distillation},
        author={Zhai, Shuwei and Wang, Guotai and Luo, Xiangde and Yue, Qiang and Li, Kang and Zhang, Shaoting},
        journal={IEEE Transactions on Medical Imaging},
        year={2023},
        publisher={IEEE}
}

Method Overview

PA-Seg trains a brain tumor segmentation model using point annotations. Each 3D image has seven annotated points: one in the foreground and six in the background. In the first stage, we expand the annotation seeds using a geodesic distance transform and train an initial model on the expanded seeds, with the unlabeled pixels regularized by a multi-view CRF loss and a Variance Minimization (VM) loss. Pseudo labels are then obtained by running inference with the initial model. In the second stage, to deal with noise in the pseudo labels, we propose Self and Cross Monitoring (SCM), where a primary model and an auxiliary model supervise each other via Cross Knowledge Distillation (CKD) based on soft labels, in addition to the self-training of each model.

*Illustration of our point annotation-based segmentation. Green: Background. Red: Foreground.*

*An overview of PA-Seg for weakly supervised 3D segmentation based on point annotations.*
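To make the idea behind the Variance Minimization (VM) regularizer more concrete, here is a minimal pure-Python sketch. It assumes the loss is the sum of the probability-weighted intensity variances of the predicted foreground and background; the function name `variance_minimization` and this exact formulation are illustrative assumptions and may differ from the implementation in this repository.

```python
def variance_minimization(intensities, fg_probs, eps=1e-8):
    """Sum of probability-weighted intensity variances of foreground and background.

    Assumed formulation for illustration only, not taken from the repository code.
    """
    total = 0.0
    # First pass uses foreground probabilities, second pass the background ones.
    for weights in (fg_probs, [1.0 - p for p in fg_probs]):
        mass = sum(weights) + eps
        mean = sum(w * x for w, x in zip(weights, intensities)) / mass
        total += sum(w * (x - mean) ** 2 for w, x in zip(weights, intensities)) / mass
    return total


intensities = [0.1, 0.2, 0.9, 1.0]
# Prediction that groups similar intensities together.
aligned = variance_minimization(intensities, [0.0, 0.0, 1.0, 1.0])
# Prediction that mixes dark and bright pixels in both classes.
mixed = variance_minimization(intensities, [0.0, 1.0, 0.0, 1.0])
print(aligned, mixed)
```

Under this assumption, a prediction whose regions are homogeneous in intensity gets a lower value than one that mixes dark and bright pixels, which is the behavior the regularizer is meant to encourage on unlabeled pixels.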

Requirements

We have only tested the code in the following environment. Please ensure that the version of each package is not lower than the one listed below.

  • Set up a virtual environment (e.g. conda or virtualenv) with Python == 3.8.10
  • Follow the official guidance to install PyTorch with torch == 1.9.1+cu111
  • Install other requirements using:
pip install -r requirements.txt
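Before running anything, it can help to confirm that the interpreter and packages meet the versions listed above. The following is a small sketch; the helpers `version_tuple` and `meets_minimum` are hypothetical names introduced here and are not part of this repository.

```python
import re
import sys


def version_tuple(version):
    """Turn '1.9.1+cu111' into (1, 9, 1), ignoring any local build suffix."""
    public = version.split("+", 1)[0]
    return tuple(int(part) for part in re.findall(r"\d+", public))


def meets_minimum(installed, minimum):
    """Return True if the installed version is not lower than the minimum."""
    return version_tuple(installed) >= version_tuple(minimum)


# Python version listed in the requirements above.
print("Python OK:", sys.version_info[:3] >= (3, 8, 10))
# Example with the torch version string listed above.
print("torch OK:", meets_minimum("1.9.1+cu111", "1.9.1"))
```

In a real environment, the installed string would come from `torch.__version__` rather than a literal.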

Data

Vestibular Schwannoma (VS) Segmentation

Step 1: Please follow the instructions here to download the NBIA Data Retriever.

Step 2: Open manifest-T2.tcia with NBIA Data Retriever and download the T2 images (DICOM, 6GB) with the "Descriptive Directory Name" format. Set the save path to ./data.

Step 3: Execute the script to convert DICOM to Nifti:

python ./data/VS/convert.py \
--input ./data/manifest-T2/Vestibular-Schwannoma-SEG \
--output ./data/VS/image

Step 4: Download the full annotations from Google Drive or ALiYun Drive and save them in ./data/VS/label.

Step 5: Execute the script to crop the images and full annotations:

python ./data/VS/image_crop.py \
--data_dir ./data/VS/ \
--dataset_split ./splits/split_VS.csv \
--image_postfix T2 \
--label_postfix Label

Step 6: Download our proposed point annotations from Google Drive or ALiYun Drive and save them in ./data/VS/annotation_7points. Note that the point annotations have already been cropped.

Brain Tumor Segmentation (BraTS)

Step 1: Please follow the instructions here to acquire the training and validation data of the BraTS 2019 challenge. Put the dataset in the directory ./data.

Step 2: Execute the script to merge the three labels "enhancing tumor", "tumor core", and "whole tumor" into a single label:

python ./data/BraTS/merge_and_move.py \
--original_dir ./data/BraTS2019/MICCAI_BraTS_2019_Data_Training/ \
--destination_dir ./data/BraTS/

Step 3: Download our proposed point annotations from Google Drive or ALiYun Drive and save them in ./data/BraTS/annotation_7points.
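Once the data steps are complete, a quick check of the directory layout can catch a missed step early. The sketch below lists the directories that the commands in the Usage section read from; the `EXPECTED` table and the `missing_dirs` helper are assumptions made for illustration and are not part of this repository.

```python
from pathlib import Path

# Directories read by the training and inference commands in this README.
EXPECTED = {
    "VS": ["image_crop", "label_crop", "annotation_7points"],
    "BraTS": ["image", "label", "annotation_7points"],
}


def missing_dirs(data_root, dataset):
    """Return the expected sub-directories of a dataset that do not exist yet."""
    root = Path(data_root) / dataset
    return [str(root / name) for name in EXPECTED[dataset] if not (root / name).is_dir()]


for dataset in EXPECTED:
    print(dataset, "missing:", missing_dirs("./data", dataset))
```

An empty list for a dataset suggests that its preparation steps produced the expected folders; it does not verify the file contents.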

Usage

The First Stage of PA-Seg: Learning from Point Annotation

Execute the script to expand the annotation seeds based on geodesic distance transform:

# VS dataset
python ./data/VS/generate_geodesic_labels.py \
--dataset_split ./splits/split_VS.csv \
--path_images ./data/VS/image_crop/ \
--image_postfix T2 \
--path_labels ./data/VS/label_crop/ \
--label_postfix Label \
--path_anno_7points ./data/VS/annotation_7points/ \
--anno_7points_postfix 7points \
--path_geodesic ./data/VS/geodesic/ \
--geodesic_weight 0.5 \
--geodesic_threshold 0.2

# BraTS dataset
python ./data/BraTS/generate_geodesic_labels.py \
--dataset_split ./splits/split_BraTS.csv \
--path_images ./data/BraTS/image/ \
--image_postfix Flair \
--path_labels ./data/BraTS/label/ \
--label_postfix Label \
--path_anno_7points ./data/BraTS/annotation_7points/ \
--anno_7points_postfix 7points \
--path_geodesic ./data/BraTS/geodesic/ \
--geodesic_weight 0.5 \
--geodesic_threshold 0.2
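For intuition about seed expansion, here is a toy 2D sketch of a geodesic distance transform followed by thresholding. It is not the repository's 3D implementation. It assumes that the weight mixes a spatial step cost with an intensity-difference cost, and that the threshold is applied to the distance normalized by its maximum; both interpretations of `--geodesic_weight` and `--geodesic_threshold` are assumptions, as are the function names.

```python
import heapq


def geodesic_distance(image, seeds, weight=0.5):
    """Dijkstra on a 2D grid; each step mixes spatial cost and intensity change."""
    rows, cols = len(image), len(image[0])
    dist = [[float("inf")] * cols for _ in range(rows)]
    heap = []
    for r, c in seeds:
        dist[r][c] = 0.0
        heap.append((0.0, r, c))
    heapq.heapify(heap)
    while heap:
        d, r, c = heapq.heappop(heap)
        if d > dist[r][c]:
            continue  # stale queue entry
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols:
                # Assumed cost: (1 - weight) * unit step + weight * intensity difference.
                step = (1.0 - weight) + weight * abs(image[nr][nc] - image[r][c])
                if d + step < dist[nr][nc]:
                    dist[nr][nc] = d + step
                    heapq.heappush(heap, (d + step, nr, nc))
    return dist


def expand_seeds(image, seeds, weight=0.5, threshold=0.2):
    """Label pixels whose normalized geodesic distance to a seed is below threshold."""
    dist = geodesic_distance(image, seeds, weight)
    largest = max(max(row) for row in dist) or 1.0
    return [[1 if d / largest < threshold else 0 for d in row] for row in dist]


# One bright run next to the seed, then a sharp intensity edge.
image = [[9, 9, 9, 0, 0, 0, 0, 0]]
print(expand_seeds(image, seeds=[(0, 0)]))
```

In this toy example the seed grows across the homogeneous bright run and stops at the intensity edge, because crossing the edge adds a large cost to the geodesic distance.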

Execute the script to train an initial model using the expanded seeds, with the unlabeled pixels regularized by a multi-view CRF loss and a Variance Minimization (VM) loss. Note that the --network parameter has three options: U_Net2D5, U_Net, and AttU_Net.

# VS dataset
python train_gatedcrfloss3d22d_multiview_varianceloss.py \
--model_dir ./models/VS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network U_Net2D5 \
--batch_size 1 \
--max_epochs 300 \
--rampup_epochs 30 \
--dataset_split ./splits/split_VS.csv \
--path_images ./data/VS/image_crop/ \
--image_postfix T2 \
--path_labels ./data/VS/label_crop/ \
--label_postfix Label \
--path_geodesic_labels ./data/VS/geodesic/weight0.5_threshold0.2/geodesic_label/ \
--geodesic_label_postfix GeodesicLabel \
--learning_rate 1e-2 \
--spatial_shape 128 128 48 \
--weight_gatedcrf 1.0 \
--down_size 64 64 48 \
--kernel_radius 5 5 -1 \
--weight_variance 0.1

# BraTS dataset
python train_gatedcrfloss3d22d_multiview_varianceloss.py \
--model_dir ./models/BraTS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network U_Net \
--batch_size 1 \
--max_epochs 500 \
--rampup_epochs 50 \
--dataset_split ./splits/split_BraTS.csv \
--path_images ./data/BraTS/image/ \
--image_postfix Flair \
--path_labels ./data/BraTS/label/ \
--label_postfix Label \
--path_geodesic_labels ./data/BraTS/geodesic/weight0.5_threshold0.2/geodesic_label/ \
--geodesic_label_postfix GeodesicLabel \
--learning_rate 1e-2 \
--spatial_shape 128 128 128 \
--weight_gatedcrf 1.0 \
--down_size 64 64 64 \
--kernel_radius 5 5 -1 \
--weight_variance 0.1
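The `--path_geodesic_labels` value in the commands above appears to embed the weight and threshold used during seed expansion. If you change those two values, the training path likely needs to change as well. The helper below is a hypothetical sketch of that naming pattern, inferred only from the paths shown in this README.

```python
def geodesic_label_dir(path_geodesic, weight, threshold):
    """Build the geodesic label directory from the seed-expansion settings.

    The naming pattern is inferred from the example paths, not from the code.
    """
    base = path_geodesic.rstrip("/")
    return f"{base}/weight{weight}_threshold{threshold}/geodesic_label/"


print(geodesic_label_dir("./data/VS/geodesic/", 0.5, 0.2))
```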

Perform inference on the test dataset using inference.py to obtain segmentation results. Then, execute utilities/scores.py to obtain evaluation metrics such as "dice" and "assd" for the segmentation results.

# VS dataset
python inference.py \
--model_dir ./models/VS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network U_Net2D5 \
--dataset_split ./splits/split_VS.csv \
--path_images ./data/VS/image_crop/ \
--image_postfix T2 \
--phase inference \
--spatial_shape 128 128 48 \
--epoch_inf best

python utilities/scores.py \
--model_dir ./models/VS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network U_Net2D5 \
--dataset_split ./splits/split_VS.csv \
--image_postfix T2 \
--phase inference \
--path_labels ./data/VS/label_crop/ \
--label_postfix Label

# BraTS dataset
python inference.py \
--model_dir ./models/BraTS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network U_Net \
--dataset_split ./splits/split_BraTS.csv \
--path_images ./data/BraTS/image/ \
--image_postfix Flair \
--phase inference \
--spatial_shape 128 128 128 \
--epoch_inf best

python utilities/scores.py \
--model_dir ./models/BraTS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network U_Net \
--dataset_split ./splits/split_BraTS.csv \
--image_postfix Flair \
--phase inference \
--path_labels ./data/BraTS/label/ \
--label_postfix Label
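For reference, the Dice score reported by the evaluation measures the overlap between a predicted mask and the ground truth. The sketch below computes it on flat binary masks in plain Python. It is an illustration of the metric only, not the code in utilities/scores.py, and returning 1.0 when both masks are empty is an assumed convention.

```python
def dice_score(pred, target):
    """Dice = 2 * |P ∩ T| / (|P| + |T|) for flat binary masks."""
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    # Assumed convention: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


pred = [0, 1, 1, 1, 0, 0]
target = [0, 0, 1, 1, 1, 0]
print(dice_score(pred, target))
```

Here two of the three predicted foreground voxels overlap with the three ground-truth voxels, giving a Dice of 4/6.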

The Second Stage of PA-Seg: Noise-robust learning using Self and Cross Monitoring (SCM)

Execute the script to train enhanced models using SCM:

# VS dataset
python train_SCM.py \
--pretrained_model1_dir ./models/VS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network1 U_Net2D5 \
--pretrained_model2_dir ./models/VS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network2 AttU_Net \
--model_dir ./models/VS/SCM/ \
--batch_size 1 \
--max_epochs 100 \
--iterative_epochs 20 \
--dataset_split ./splits/split_VS.csv \
--path_images ./data/VS/image_crop/ \
--image_postfix T2 \
--path_labels ./data/VS/label_crop/ \
--label_postfix Label \
--learning_rate 1e-2 \
--spatial_shape 128 128 48 \
--weight_kd 0.5 \
--T 4.0

# BraTS dataset
python train_SCM.py \
--pretrained_model1_dir ./models/BraTS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network1 AttU_Net \
--pretrained_model2_dir ./models/BraTS/gatedcrfloss3d22d_multiview_varianceloss/ \
--network2 U_Net \
--model_dir ./models/BraTS/SCM/ \
--batch_size 1 \
--max_epochs 100 \
--iterative_epochs 20 \
--dataset_split ./splits/split_BraTS.csv \
--path_images ./data/BraTS/image/ \
--image_postfix Flair \
--path_labels ./data/BraTS/label/ \
--label_postfix Label \
--learning_rate 1e-2 \
--spatial_shape 128 128 128 \
--weight_kd 0.5 \
--T 4.0
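To illustrate how soft labels and a temperature such as `--T 4.0` interact in knowledge distillation, here is a pure-Python sketch for a single voxel. It assumes a KL divergence between temperature-softened softmax outputs, with each model distilled toward the other. The function names, the choice of KL divergence, and the way `--weight_kd` combines this term with the self-training loss are assumptions and may differ from train_SCM.py.

```python
import math


def soften(logits, temperature):
    """Softmax of logits divided by the temperature (higher T gives softer labels)."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    norm = sum(exps)
    return [e / norm for e in exps]


def kd_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on temperature-softened distributions."""
    teacher = soften(teacher_logits, temperature)
    student = soften(student_logits, temperature)
    return sum(t * math.log(t / s) for t, s in zip(teacher, student))


def cross_kd(logits_model1, logits_model2, temperature=4.0):
    """Each model is distilled toward the other model's soft labels."""
    loss1 = kd_loss(logits_model1, logits_model2, temperature)
    loss2 = kd_loss(logits_model2, logits_model1, temperature)
    return loss1, loss2


# Background/foreground logits of the two models for one voxel.
print(soften([2.0, 0.0], 1.0), soften([2.0, 0.0], 4.0))
print(cross_kd([2.0, 0.0], [0.5, 1.5]))
```

A higher temperature flattens the distribution, so the distillation term carries information about how confident the other model is rather than only its hard decision, which is useful when the pseudo labels are noisy.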

Run inference using the main network by executing inference.py. Then, execute utilities/scores.py to obtain the evaluation metrics.

# VS dataset
python inference.py \
--model_dir ./models/VS/SCM/ \
--network U_Net2D5 \
--dataset_split ./splits/split_VS.csv \
--path_images ./data/VS/image_crop/ \
--image_postfix T2 \
--phase inference \
--spatial_shape 128 128 48 \
--epoch_inf best_model1

python utilities/scores.py \
--model_dir ./models/VS/SCM/ \
--network U_Net2D5 \
--dataset_split ./splits/split_VS.csv \
--image_postfix T2 \
--phase inference \
--path_labels ./data/VS/label_crop/ \
--label_postfix Label

# BraTS dataset
python inference.py \
--model_dir ./models/BraTS/SCM/ \
--network U_Net \
--dataset_split ./splits/split_BraTS.csv \
--path_images ./data/BraTS/image/ \
--image_postfix Flair \
--phase inference \
--spatial_shape 128 128 128 \
--epoch_inf best_model1

python utilities/scores.py \
--model_dir ./models/BraTS/SCM/ \
--network U_Net \
--dataset_split ./splits/split_BraTS.csv \
--image_postfix Flair \
--phase inference \
--path_labels ./data/BraTS/label/ \
--label_postfix Label

Our pre-trained models can be downloaded from Google Drive or ALiYun Drive.

Acknowledgement

This code is adapted from InExtremIS. We thank Dr. Reuben Dorent for his elegant and efficient code base.