This repository is the official implementation of Self-Supervised Pretraining for 2D Medical Image Segmentation (accepted for the AIMIA workshop at ECCV 2022).
If you use our code or results, please cite our paper:
@InProceedings{Kalapos2022,
author = {Kalapos, Andr{\'a}s and Gyires-T{\'o}th, B{\'a}lint},
booktitle = {Computer Vision -- ECCV 2022 Workshops},
title = {{Self-supervised Pretraining for 2D Medical Image Segmentation}},
year = {2023},
address = {Cham},
pages = {472--484},
publisher = {Springer Nature Switzerland},
doi = {10.1007/978-3-031-25082-8_31},
isbn = {978-3-031-25082-8},
}
To install pypi requirements:
pip install -r requirements.txt
For self-supervised pre-training, solo-learn==1.0.6 is also required. For its installation, follow the instructions in solo-learn's documentation (DALI and UMAP support are not needed) or use the following commands:
git clone https://github.com/vturrisi/solo-learn.git
cd solo-learn
pip3 install -e .
Download the ACDC Segmentation dataset from: https://acdc.creatis.insa-lyon.fr (registration required)
Specify the path for the dataset in:
- supervised_segmentation/config_acdc.yaml ->
dataset_root: "<YOUR DATASET PATH>"
- self_supervised/configs/byol_acdc.yaml#L32 ->
train_path: "<YOUR DATASET PATH>"
For hyperparameter sweeps and running many experiments in a batch, we use Slurm jobs, so an installed and configured Slurm environment is required for these runs. However, based on supervised_segmentation/sweeps/data_eff_learning.sh and supervised_segmentation/sweeps/grid_search_helper.py, other ways of running experiments in a batch can be implemented, as sketched below.
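As a minimal sketch of such an alternative (standard library only; the sweep parameters and CLI flags below are hypothetical and must be matched to the options supervised_segmentation/train.py actually exposes), a plain Python launcher could enumerate a grid and start one run per combination:

```python
# Hypothetical Slurm-free sweep runner: the "--lr" and "--labeled_ratio"
# flags are placeholders -- replace them with the real train.py options.
import itertools
import os
import subprocess

learning_rates = [1e-4, 3e-4]
labeled_ratios = [0.1, 0.5, 1.0]

env = {**os.environ, "PYTHONPATH": "."}  # train.py is run with PYTHONPATH=.
for lr, ratio in itertools.product(learning_rates, labeled_ratios):
    subprocess.run(
        ["python", "supervised_segmentation/train.py",
         "--lr", str(lr), "--labeled_ratio", str(ratio)],
        env=env,
        check=True,  # abort the sweep if a run fails
    )
```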
To train the model(s) in the paper, run this command:
PYTHONPATH=. python supervised_segmentation/train.py
On a Slurm cluster:
PYTHONPATH=. srun -p gpu --gres=gpu --cpus-per-task=10 python supervised_segmentation/train.py
To initialize the downstream training with different pretrained models, we provide the pretrained weights used in our paper. These can be selected by setting the encoder_weights config in supervised_segmentation/config_acdc.yaml.
For ImageNet pretraining, we acquire weights from:
- timm for supervised ImageNet pretraining
- github.com/yaox12/BYOL-PyTorch for BYOL ImageNet pretraining [Google drive link for their model file]
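For reference, a rough sketch of how these weights can be obtained in code; timm.create_model is timm's real API, while the key layout of the BYOL checkpoint is an assumption that should be verified against the downloaded file:

```python
import timm
import torch

# Supervised ImageNet weights via timm.
encoder = timm.create_model("resnet50", pretrained=True)

# BYOL ImageNet weights from yaox12/BYOL-PyTorch; the "online_backbone"
# key is a guess -- inspect ckpt.keys() to find the actual backbone entry.
ckpt = torch.load("resnet50_byol_imagenet2012.pth.tar", map_location="cpu")
state_dict = ckpt.get("online_backbone", ckpt)
```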
To pretrain a model on the ACDC dataset, run this command:
python self_supervised/main_pretrain.py --config-path configs/ --config-name byol_acdc.yaml
Different pretraining strategies can be configured by specifying different --pretrained_weights and --max_epochs arguments, or by modifying the corresponding configs in self_supervised/configs/byol_acdc.yaml.
| Pre-training approach | Corresponding arrow on the figure above | `--pretrained_weights` | `--max_epochs` | Published pretrained model [models.zip] |
|---|---|---|---|---|
| Supervised ImageNet + BYOL ACDC | (1st step) | `supervised_imagenet` | 25 | models/supervised-imagenet-byol-acdc-ep=25.pth |
| BYOL ImageNet + BYOL ACDC | (1st step) | `resnet50_byol_imagenet2012.pth.tar` | 25 | models/byol-imagenet-acdc-ep=34.ckpt and models/byol-imagenet-acdc-ep=25.pth |
| BYOL ACDC | | `None` | 400 | models/byol_acdc_backbone_last.pth |
We publish pretrained models for these pretraining strategies, as listed in the last column of the table above.
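As an illustration of how such a file could be used outside this codebase (assuming the published .pth files store a plain backbone state_dict, which should be verified before relying on it):

```python
import torch
import torchvision

# Load a published BYOL-pretrained backbone into a torchvision ResNet-50.
backbone = torchvision.models.resnet50()
state_dict = torch.load("models/byol_acdc_backbone_last.pth", map_location="cpu")
# strict=False tolerates the missing classifier head if only backbone
# weights are stored; check the reported missing/unexpected keys.
print(backbone.load_state_dict(state_dict, strict=False))
```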
To evaluate the segmentation model on the ACDC dataset, run:
PYTHONPATH=. python supervised_segmentation/inference_acdc.py
Custom segmentation datasets can be loaded via mmsegmentation's custom dataset API. This requires the data to be organized in the following directory structure:
├── data
│ ├── custom_dataset
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{img_suffix}
│ │ │ │ ├── yyy{img_suffix}
│ │ │ │ ├── zzz{img_suffix}
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{seg_map_suffix}
│ │ │ │ ├── yyy{seg_map_suffix}
│ │ │ │ ├── zzz{seg_map_suffix}
│ │ │ ├── val
The following configurations must be specified in data/config/custom.py:
data_root = PATH_TO_DATASET # Path to custom dataset
num_classes = NUM_CLASSES # Number of segmentation classes
in_channels = IN_CHANNELS # Number of input channels (e.g. 3 for RGB data)
class_labels = CLASS_LABELS # Class labels used for logging
ignore_label = IGNORE_LABEL # Ignored label during iou metric computation
img_suffix = '.tiff'          # Filename suffix of input images
seg_map_suffix = '_gt.tiff'   # Filename suffix of segmentation maps
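Before training, a quick stdlib-only sanity check (assuming the layout and suffixes above) can confirm that every training image has a matching annotation:

```python
# Verify image/annotation pairing under the directory structure shown above,
# using the suffixes configured in data/config/custom.py.
from pathlib import Path

data_root = Path("data/custom_dataset")
img_suffix = ".tiff"
seg_map_suffix = "_gt.tiff"

for img in sorted((data_root / "img_dir" / "train").glob(f"*{img_suffix}")):
    stem = img.name[: -len(img_suffix)]
    ann = data_root / "ann_dir" / "train" / f"{stem}{seg_map_suffix}"
    if not ann.exists():
        print(f"Missing annotation for {img.name}")
```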
To train the model(s), run this command:
PYTHONPATH=. python supervised_segmentation/train.py --config_path supervised_segmentation/config_custom.yaml
If your custom dataset is in a simple image-folder format, solo-learn's built-in data loading should handle it (including DALI support). In this case you only need to specify the following paths in self_supervised/configs/byol_custom.yaml:
train_path: "PATH_TO_TRAIN_DIR"
val_path: "PATH_TO_VAL_DIR" # optional
To run the SSL pretraining on a custom dataset:
python self_supervised/main_pretrain.py --config-path configs/ --config-name byol_custom.yaml
For more complex cases, you can build a custom dataset class (similar to data.acdc_dataset.ACDCDatasetUnlabeleld) and instantiate it in self_supervised/main_pretrain.py#L183, as sketched below.
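A minimal sketch of such a class (the file discovery, suffix, and return signature below are assumptions; mirror how data.acdc_dataset.ACDCDatasetUnlabeleld is written and consumed in main_pretrain.py):

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class CustomDatasetUnlabeled(Dataset):
    """Hypothetical unlabeled image dataset for SSL pretraining."""

    def __init__(self, root, transform=None, suffix=".tiff"):
        self.paths = sorted(Path(root).glob(f"**/*{suffix}"))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img
```

Depending on solo-learn's data pipeline, __getitem__ may need to return a dummy label alongside the image; check how the ACDC dataset class is wrapped in main_pretrain.py.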
Segmentation code is based on: ternaus/cloths_segmentation
Self-supervised training script is based on: vturrisi/solo-learn
BYOL ImageNet pretrained model from: yaox12/BYOL-PyTorch [Google drive link for their model file]
This Readme is based on: paperswithcode/releasing-research-code