By Yuncheng Jiang*, Zixun Zhang*, Shixi Qin, Yao Guo, Zhen Li, Shuguang Cui.
This repo is a PyTorch implementation of "APAUNet: Axis Projection Attention UNet for Small Target Segmentation in 3D Medical Images", accepted by ACCV 2022.
APAUNet is a segmentation network for 3D medical image data. It aims to improve small-target segmentation accuracy by applying a projection-based 2D attention mechanism along each of the three axes. For more details, please refer to the paper.
The model was trained and tested on Task 08 (Hepatic Vessels) of the Medical Segmentation Decathlon. APAUNet has shown especially good results when segmenting small objects (such as tumors or organs), but it had not been tried before in a vessel segmentation context.
The results obtained for vessel segmentation are competitive with the public results of the Medical Segmentation Decathlon. The code was modified to run in Polyaxon. Each change made to the code is explained with a comment, and some files (such as tester.py) have been added.
Datasets can be acquired via the following links:
After you have downloaded the datasets, you can follow the settings in nnUNet for path configuration and preprocessing.
Finally, your folders should be organized as follows:
./DATASET/
  ├── nnUNet_raw/
      ├── nnUNet_raw_data/
          ├── Task01_Liver/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
          ├── Task02_Pancreas/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
          ├── Task03_Synapse/
              ├── imagesTr/
              ├── imagesTs/
              ├── labelsTr/
              ├── labelsTs/
              ├── dataset.json
      ├── nnUNet_cropped_data/
  ├── nnUNet_trained_models/
  ├── nnUNet_preprocessed/
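Before preprocessing, it can help to confirm that a task folder matches the layout above. The following is a minimal sketch, assuming task folders sit under nnUNet_raw/nnUNet_raw_data/ as in the nnUNet convention. The helper name `missing_entries` is hypothetical and is not part of this repo.

```python
from pathlib import Path

# Sub-folders and files each task folder is expected to contain,
# taken from the directory listing above.
EXPECTED_ENTRIES = ("imagesTr", "imagesTs", "labelsTr", "labelsTs", "dataset.json")


def missing_entries(dataset_root, task_name):
    """Return the expected entries that are absent from one task folder.

    dataset_root is the ./DATASET/ folder; task_name is a folder name
    such as "Task08_HepaticVessel". (Hypothetical helper, not repo code.)
    """
    task_dir = Path(dataset_root) / "nnUNet_raw" / "nnUNet_raw_data" / task_name
    return [name for name in EXPECTED_ENTRIES if not (task_dir / name).exists()]


if __name__ == "__main__":
    # Adjust the root path and task name to your own setup.
    print(missing_entries("./DATASET", "Task08_HepaticVessel"))
```

An empty list means every expected entry was found; otherwise the list names what still needs to be created or downloaded.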
We use the NVIDIA nnUNet preprocessing pipeline:
git clone https://github.com/NVIDIA/DeepLearningExamples.git
cd DeepLearningExamples/PyTorch/Segmentation/nnUNet
After that, you can preprocess the above data using the following command:
python preprocess.py --task 08 --dim 3 --data "<path_to_dataset>" --results "<output_folder>"
The preprocessed data is then stored in the output folder.
The code is set to run directly in Polyaxon, so each path is relative to the NAS folder.
All the training and testing hyper-parameters are set in config.py. You can modify these configurations according to your requirements.
Training parameters:
model_name: (string) name of the model to use (default: APAUNet)
dataset: (string) name of the dataset to use (default: Task08_HepaticVessel)
data_path: (string) path to the folder containing the dataset (relative to the NAS)
scheduler: (string) learning rate scheduler to use (available options: CosineAnnealingLR, MultiStepLR, StepLR)
criterion: (string) loss function to use (available options: DiceFocal, Dice, DiceWeighted, BinaryFocalLoss, DiceBCELoss)
optimizer: (string) optimizer to use (available options: SGD, Adam, adamw)
gamma: (float) gamma parameter of the focal loss
alpha: (float) alpha parameter of the focal loss
lr: (float) learning rate for training the model
epochs: (int) number of epochs to train the model
val_interval: (int) number of epochs between evaluations on the validation set
batch_size: (int) number of volumes per batch
in_ch: (int) number of channels in the input volume (default: 1)
class_num: (int) number of output classes (default: 1)
val_num: (int) number of volumes to use for validation
input_shape: (int, int, int) shape of the crops taken from the input volumes and fed to the model
resume: (string) path to the folder containing the .pth model checkpoint from which to continue training
use_cuda: (bool) whether to run the model on the GPU
debug: (bool) flag for debugging changes in the training loop (when set, parameters are changed to epochs=30, val_num=2, data_path="path_to_small_training_set")
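The effect of the debug flag can be pictured with a small configuration object. This is only a sketch under stated assumptions: the class `TrainConfig` and the function `apply_debug` are hypothetical names, the non-debug values of epochs, val_num and data_path are placeholders rather than the repository's defaults, and the real config.py may be organized differently.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainConfig:
    # Names and defaults documented in the parameter list.
    model_name: str = "APAUNet"
    dataset: str = "Task08_HepaticVessel"
    in_ch: int = 1
    class_num: int = 1
    # Placeholder values for illustration, NOT the repository's defaults.
    data_path: str = "path_to_training_set"
    epochs: int = 200
    val_num: int = 10
    debug: bool = False


def apply_debug(cfg):
    """Return a copy with the documented debug overrides if cfg.debug is set."""
    if not cfg.debug:
        return cfg
    return replace(cfg, epochs=30, val_num=2, data_path="path_to_small_training_set")


if __name__ == "__main__":
    print(apply_debug(TrainConfig(debug=True)))
```

Keeping the overrides in one function makes it easy to see exactly which settings a debug run changes and which it leaves alone.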
Testing parameters:
model_name: (string) name of the model to use (default: APAUNet)
dataset: (string) name of the dataset to use (default: Task08_HepaticVessel)
data_dir: (string) path to the folder containing the dataset (relative to the NAS)
batch_size: (int) number of volumes per batch
in_ch: (int) number of channels in the input volume (default: 1)
class_num: (int) number of output classes (default: 1)
val_num: (int) number of volumes to use for validation
input_shape: (int, int, int) shape of the crops taken from the input volumes and fed to the model
resume: (string) path to the folder containing the .pth model checkpoint to be tested
use_cuda: (bool) whether to run the model on the GPU
To train the model, run:
python train.py
To run training in Polyaxon, use the polyaxonfile that has already been prepared:
polyaxon run -p <project_name> -f polyaxonfile.yaml
To test a trained model, run:
python tester.py
This produces a predicted segmentation mask for the given test data.
To run testing in Polyaxon, use the separate polyaxonfile prepared for it:
polyaxon run -p <project_name> -f polyaxonfileTest.yaml
The Dice scores (%) obtained on the test set created from Task 08 of the Medical Segmentation Decathlon are shown in the following table.
Architecture | Dice Score |
---|---|
APAUNet | 65.4% |
CLIP* | 67% |
*CLIP is the model ranked first in the Medical Segmentation Decathlon as of 27.03.2023.
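For readers unfamiliar with the metric, the Dice score between a predicted mask P and a ground-truth mask T is 2|P ∩ T| / (|P| + |T|). The sketch below computes it for two small binary masks in plain Python. It illustrates the standard definition only; the evaluation in tester.py may differ in details such as per-volume averaging or smoothing terms, and the function name `dice_score` is an assumption.

```python
def dice_score(pred, target):
    """Dice = 2*|P ∩ T| / (|P| + |T|) for binary masks given as flat 0/1 sequences."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of voxels")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # Convention assumed here: two empty masks count as a perfect match.
    return 1.0 if total == 0 else 2.0 * intersection / total


# Toy example: 3 predicted voxels, 3 true voxels, 2 of them overlapping.
pred = [1, 1, 0, 0, 1, 0]
target = [1, 0, 0, 0, 1, 1]
print(f"{100 * dice_score(pred, target):.1f}%")  # prints 66.7%
```

Because the denominator counts only foreground voxels, missing a few voxels of a thin structure such as a vessel lowers the score far more than the same error would on a large organ.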
@inproceedings{apaunet2022,
title={APAUNet: Axis Projection Attention UNet for Small Target Segmentation in 3D Medical Images},
author={Jiang, Yuncheng and Zhang, Zixun and Qin, Shixi and Guo, Yao and Li, Zhen and Cui, Shuguang},
booktitle={Proceedings of the Asian Conference on Computer Vision},
year={2022}
}