Code for the pseudo-label based contrastive learning joint-training approach
Local contrastive loss with pseudo-label based self-training for semi-supervised medical image segmentation
This code accompanies the article "Local contrastive loss with pseudo-label based self-training for semi-supervised medical image segmentation" (under review). With the proposed joint-training method using a contrastive loss, we obtain segmentation performance competitive with the upper bound and the compared methods using just 2 labeled training volumes.
https://arxiv.org/abs/2112.09645
Authors:
Krishna Chaitanya,
Ertunc Erdil,
Neerav Karani,
Ender Konukoglu.
Requirements:
Python 3.6.1,
Tensorflow 1.12.0,
The rest of the requirements are listed in the "requirements.txt" file.
I) Clone the git repository.
git clone https://github.com/krishnabits001/pseudo_label_contrastive_training.git
II) Install Python, the required packages and TensorFlow.
First install Python 3.6.1. Then, install the required Python packages using the command below, or install the packages listed in the file individually.
pip install -r requirements.txt
To install TensorFlow:
pip install tensorflow-gpu==1.12.0
III) Dataset download.
To download the ACDC Cardiac dataset, see the website:
https://www.creatis.insa-lyon.fr/Challenge/acdc
To download the Medical Decathlon Prostate dataset, see the website:
http://medicaldecathlon.com/
To download the MMWHS Cardiac dataset, see the website:
http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/
All images were bias-corrected using the N4 algorithm with a threshold value of 0.001. For more details, refer to the "N4_bias_correction.py" file in the scripts directory.
Image and label pairs are resampled to a chosen target resolution and then cropped or zero-padded to a fixed size using the "create_cropped_imgs.py" file.
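The resample-then-crop/pad step can be illustrated with a minimal NumPy sketch. It is not the code in "create_cropped_imgs.py": the helpers `resampled_shape` and `crop_or_pad` are hypothetical, the interpolation itself is left out, and centring the crop or padding is an assumption.

```python
import numpy as np


def resampled_shape(shape, pixel_size, target_res):
    """In-plane shape after resampling from pixel_size to target_res (both in mm)."""
    return tuple(int(round(s * p / t)) for s, p, t in zip(shape, pixel_size, target_res))


def crop_or_pad(img, target_shape):
    """Center-crop or zero-pad the leading axes of img to target_shape.

    Hypothetical helper for illustration; axes beyond len(target_shape)
    (e.g. the slice axis) are left untouched. Applied to a label map,
    zero-padding adds background (label 0).
    """
    out = img
    for axis, target in enumerate(target_shape):
        size = out.shape[axis]
        if size > target:
            # Crop: keep the central `target` entries along this axis.
            start = (size - target) // 2
            out = np.take(out, np.arange(start, start + target), axis=axis)
        elif size < target:
            # Zero-pad: split the padding evenly, any odd extra goes at the end.
            before = (target - size) // 2
            pad = [(0, 0)] * out.ndim
            pad[axis] = (before, target - size - before)
            out = np.pad(out, pad, mode="constant", constant_values=0)
    return out


# Example: a 200 x 300 x 10 volume brought to 256 x 256 in-plane.
vol = np.ones((200, 300, 10))
print(crop_or_pad(vol, (256, 256)).shape)  # (256, 256, 10)
```

Centring is just a simple default here; the original script may place the crop window differently.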
IV) Train the model.
To perform joint training, run the script "pseudo_lbl_rand_init.sh" in the train_model directory.
For instance, to train on the ACDC dataset with 2 training volumes and configuration c1, use the command below.
bash pseudo_lbl_rand_init.sh tr2 c1 acdc
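For readers who prefer to drive the pipeline from Python, here is a minimal sketch of what the shell script is described as doing: run Step 1 and then Step 2 (listed below) from train_model/ with the same three flags. It assumes the script does nothing beyond these two commands; `build_commands` and `run_pipeline` are hypothetical helpers, not part of the repository.

```python
import subprocess
import sys


def build_commands(no_of_tr_imgs, comb_tr_imgs, dataset):
    """Hypothetical helper: the two training commands with identical flags."""
    flags = [
        f"--no_of_tr_imgs={no_of_tr_imgs}",
        f"--comb_tr_imgs={comb_tr_imgs}",
        f"--dataset={dataset}",
    ]
    return [
        # Step 1: baseline network that produces the initial pseudo-labels.
        [sys.executable, "tr_baseline.py", *flags],
        # Step 2: joint training with contrastive and segmentation losses.
        [sys.executable, "prop_method_joint_tr_rand_init.py", *flags],
    ]


def run_pipeline(no_of_tr_imgs, comb_tr_imgs, dataset, workdir="train_model"):
    """Hypothetical helper: run both steps in order inside workdir."""
    for cmd in build_commands(no_of_tr_imgs, comb_tr_imgs, dataset):
        # check=True raises if Step 1 fails, so Step 2 never starts without it.
        subprocess.run(cmd, cwd=workdir, check=True)
```

Under that assumption, calling `run_pipeline("tr2", "c1", "acdc")` from the repository root mirrors `bash pseudo_lbl_rand_init.sh tr2 c1 acdc`; `sys.executable` simply reuses the current Python interpreter instead of whatever `python` resolves to.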
The pseudo_lbl_rand_init.sh script executes the following 2 training steps:
Steps:
1) Step 1: Train a baseline network to infer the initial pseudo-labels for the unlabeled data. This training is done only once, at the start.
cd train_model/
python tr_baseline.py --no_of_tr_imgs=tr2 --comb_tr_imgs=c1 --dataset=acdc
2) Step 2: After Step 1, we infer pseudo-labels for the unlabeled data and perform joint training with the contrastive loss and the segmentation loss. This training is done iteratively, with the pseudo-labels refined periodically.
python prop_method_joint_tr_rand_init.py --no_of_tr_imgs=tr2 --comb_tr_imgs=c1 --dataset=acdc
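Step 2 combines a segmentation loss with a contrastive loss whose positive pairs come from the (pseudo-)labels. As a rough illustration only, the NumPy sketch below implements a generic supervised contrastive loss over sampled pixel embeddings: pixels sharing a label are pulled together and all other pixels pushed apart. It is a stand-in, not the paper's local contrastive loss, which may differ in how pixels are sampled and how local regions are defined; the function name and the default temperature of 0.1 are assumptions.

```python
import numpy as np


def pseudo_label_contrastive_loss(features, labels, temperature=0.1):
    """Generic supervised contrastive loss over N pixel embeddings.

    features: (N, D) array of pixel embeddings sampled from a feature map.
    labels:   (N,) integer array of (pseudo-)labels for those pixels.
    Assumes at least one anchor has a same-label partner.
    """
    z = features / np.linalg.norm(features, axis=1, keepdims=True)  # L2-normalise
    sim = z @ z.T / temperature                     # (N, N) scaled cosine similarity
    n = len(labels)
    not_self = ~np.eye(n, dtype=bool)
    # Stable log-sum-exp over all other pixels (self excluded from the denominator).
    sim_masked = np.where(not_self, sim, -np.inf)
    row_max = sim_masked.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(sim_masked - row_max).sum(axis=1, keepdims=True))
    log_prob = sim - log_denom
    # Positives: other pixels carrying the same (pseudo-)label.
    positives = (labels[:, None] == labels[None, :]) & not_self
    n_pos = positives.sum(axis=1)
    valid = n_pos > 0  # anchors without any positive are skipped
    mean_log_prob_pos = np.where(positives, log_prob, 0.0).sum(axis=1)[valid] / n_pos[valid]
    return -mean_log_prob_pos.mean()
```

In joint training this term would be added to the segmentation loss, e.g. `total = seg_loss + lam * con_loss` with a hypothetical weight `lam`; the weighting the authors use is not given in this README.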
V) Config file contents.
One can modify the contents of the 2 config files below to run the required experiments.
The experiment_init directory contains these 2 files.
Example for ACDC dataset:
- init_acdc.py
--> contains config details such as the target resolution, image dimensions, the data path where the dataset is stored and the path for saving the trained models.
- data_cfg_acdc.py
--> contains an example data config in which one can set the patient ids to use as train, validation and test images.
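A data config of this kind can be as simple as a lookup from the `--no_of_tr_imgs`/`--comb_tr_imgs` values to patient ids, plus fixed validation and test lists. The sketch below is hypothetical: the names `train_data`, `VAL_IDS`, `TEST_IDS`, `check_disjoint` and the placeholder ids (`patient_A`, ...) are not taken from data_cfg_acdc.py and do not reflect the splits used in the article.

```python
# Hypothetical data config sketch; ids are placeholders, not the paper's splits.

def train_data(no_of_tr_imgs, comb_tr_imgs):
    """Return labeled training patient ids for a (tr*, c*) setting."""
    splits = {
        ("tr2", "c1"): ["patient_A", "patient_B"],
        ("tr2", "c2"): ["patient_C", "patient_D"],
    }
    return splits[(no_of_tr_imgs, comb_tr_imgs)]


VAL_IDS = ["patient_E", "patient_F"]
TEST_IDS = ["patient_G", "patient_H", "patient_I"]


def check_disjoint(*splits):
    """Raise ValueError if any patient id appears in more than one split."""
    seen = set()
    for split in splits:
        overlap = seen.intersection(split)
        if overlap:
            raise ValueError(f"patients in more than one split: {sorted(overlap)}")
        seen.update(split)
```

Running `check_disjoint(train_data("tr2", "c1"), VAL_IDS, TEST_IDS)` after editing the ids guards against a patient leaking between train, validation and test sets.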
Bibtex citation:
@article{chaitanya2021local,
  title={Local contrastive loss with pseudo-label based self-training for semi-supervised medical image segmentation},
  author={Chaitanya, Krishna and Erdil, Ertunc and Karani, Neerav and Konukoglu, Ender},
  journal={arXiv preprint arXiv:2112.09645},
  year={2021}
}