
MICCAI 2022 Paper


DeSD-code

This is the official PyTorch implementation of our MICCAI 2022 paper "DeSD: Self-Supervised Learning with Deep Self-Distillation for 3D Medical Image Segmentation". In this paper, we reformulate self-supervised learning (SSL) as Deep Self-Distillation (DeSD) to improve the representation quality of both shallow and deep layers.
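The idea of deep self-distillation can be illustrated with a small, framework-free sketch. It assumes a DINO-style objective (part of this code is reused from DINO): the teacher's output is centered and sharpened with a low temperature, and each student output, taken at a different depth, is pulled toward it with a cross-entropy loss. The temperatures (0.04 and 0.1, DINO's defaults), the plain averaging over depths, and the helper names distill_ce and deep_self_distillation_loss are illustrative assumptions, not the repository's implementation; in particular, we assume each depth has its own projection head producing logits of the teacher's size.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float) -> List[float]:
    """Temperature-scaled log-softmax, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def distill_ce(teacher: Sequence[float], student: Sequence[float],
               center: Sequence[float], t_teacher: float = 0.04,
               t_student: float = 0.1) -> float:
    """DINO-style cross-entropy H(P_teacher, P_student) for one pair of outputs.

    The teacher logits are centered and sharpened with a low temperature;
    in real training only the student branch would receive gradients.
    """
    p_t = [math.exp(v) for v in
           log_softmax([z - c for z, c in zip(teacher, center)], t_teacher)]
    log_p_s = log_softmax(student, t_student)
    return -sum(p * q for p, q in zip(p_t, log_p_s))


def deep_self_distillation_loss(teacher: Sequence[float],
                                students_per_depth: Sequence[Sequence[float]],
                                center: Sequence[float]) -> float:
    """Average the distillation loss over student outputs taken at several depths.

    Hypothetical simplification: every depth is assumed to have its own
    projection head whose logits have the same size as the teacher's.
    """
    losses = [distill_ce(teacher, s, center) for s in students_per_depth]
    return sum(losses) / len(losses)


# Toy logits for illustration only.
teacher = [5.0, 0.0, 0.0]
center = [0.0, 0.0, 0.0]
agree, disagree = [5.0, 0.0, 0.0], [0.0, 5.0, 0.0]
print(deep_self_distillation_loss(teacher, [agree, disagree], center))
```

In this sketch, a student output that agrees with the teacher contributes a near-zero term while a disagreeing one is penalized, so averaging over depths gives shallower outputs a direct training signal of their own.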

The abstract of the paper is available here.

Figure: DeSD illustration.

Requirements

CUDA 10.1
Python 3.6
PyTorch 1.7.1
Torchvision 0.8.2

Usage

Installation

  • Clone this repo.
git clone https://github.com/yeerwen/DeSD.git
cd DeSD

Data Preparation

Pre-processing

  • Run DL_save_nifti.py (provided with the downloaded dataset files) to convert the PNG images to NIfTI (.nii.gz) format.
  • Run re_spacing_ITK.py to resample the CT volumes to a uniform voxel spacing.
  • Run splitting_to_patches.py to extract about 125k sub-volumes; the pre-processed dataset is saved in DL_patches_v2/.
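The resampling and patch-extraction steps above largely reduce to index arithmetic. The sketch below computes the grid size after resampling (physical extent divided by the new spacing) and the start indices of non-overlapping sliding-window patches, shifting the last window back so it ends at the border. The 1 mm target spacing, the 96-voxel patch size and stride, and the example volume are hypothetical values chosen for illustration; re_spacing_ITK.py and splitting_to_patches.py may use different settings or sample patches differently.

```python
from itertools import product
from typing import List, Sequence, Tuple


def resampled_shape(shape: Sequence[int], spacing: Sequence[float],
                    target_spacing: Sequence[float]) -> Tuple[int, ...]:
    """Grid size after resampling: physical extent (size * spacing) / new spacing."""
    return tuple(int(round(n * s / t)) for n, s, t in zip(shape, spacing, target_spacing))


def patch_starts(length: int, patch: int, stride: int) -> List[int]:
    """Sliding-window start indices along one axis.

    The last window is shifted back so it ends at the border. An axis shorter
    than the patch gets a single start at 0 (it would need padding).
    """
    if length <= patch:
        return [0]
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] != length - patch:
        starts.append(length - patch)
    return starts


# Hypothetical values: not taken from the repository's scripts.
shape = resampled_shape((512, 512, 100), (0.8, 0.8, 5.0), (1.0, 1.0, 1.0))
corners = list(product(*(patch_starts(n, 96, 96) for n in shape)))
print(shape, len(corners))
```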

Training

  • Run sh run_ssl.sh for self-supervised pre-training.
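In DINO-style self-distillation, from which part of this code is reused, the teacher network is not trained by back-propagation; its weights follow an exponential moving average (EMA) of the student's, with a momentum that rises from 0.996 toward 1.0 on a cosine schedule. The sketch below shows that update on plain lists of numbers. We assume the pre-training launched by run_ssl.sh follows this scheme; the momentum values are DINO's defaults and have not been checked against this repository's configuration.

```python
import math
from typing import List


def cosine_momentum(step: int, total_steps: int, base: float = 0.996,
                    final: float = 1.0) -> float:
    """Cosine schedule raising the EMA momentum from base (step 0) to final (last step)."""
    return final + 0.5 * (base - final) * (1 + math.cos(math.pi * step / total_steps))


def ema_update(teacher: List[float], student: List[float], momentum: float) -> List[float]:
    """Parameter-wise teacher <- m * teacher + (1 - m) * student."""
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]


# Toy "parameters"; in practice these are the teacher and student network weights.
teacher, student = [0.5, -0.2], [1.0, 0.0]
total_steps = 4
for step in range(total_steps):
    teacher = ema_update(teacher, student, cosine_momentum(step, total_steps))
print(teacher)
```

In a PyTorch loop the same update would be applied to every pair of teacher and student parameters inside torch.no_grad(). Raising the momentum toward 1.0 makes the teacher change more slowly late in training, which keeps its targets stable.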

Pre-trained Model

Fine-tune DeSD on your own target task

For downstream segmentation tasks, the encoder of a 3D segmentation model can be initialized with the pre-trained weights, as in the following example:

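import torch
from torch import nn


# Build a 3D segmentation model on top of a 3D ResNet-50 encoder.
# `Decoder` is not defined here: supply your own decoder class that consumes the
# multi-scale encoder features and outputs `decoder_channels` feature maps.
class ResNet50_Decoder(nn.Module):
    def __init__(self, Resnet50_encoder, skip_connection, n_class=1, pre_training=True,
                 load_path=None, decoder_channels=32):  # decoder_channels: hypothetical default
        super().__init__()

        self.encoder = Resnet50_encoder
        self.decoder = Decoder(skip_connection)
        # 1x1x1 convolution mapping decoder features to per-voxel class logits;
        # decoder_channels must match the number of channels your Decoder outputs.
        self.seg_head = nn.Conv3d(decoder_channels, n_class, kernel_size=1)

        if pre_training:
            print('loading from checkpoint ssl: {}'.format(load_path))
            # Record one weight statistic to verify that loading changed the weights.
            w_before = self.encoder.state_dict()['conv1.weight'].mean().item()
            # Use the teacher weights saved by the self-supervised pre-training run.
            pre_dict = torch.load(load_path, map_location='cpu')['teacher']
            # Strip the "module." (DDP) and "backbone." (wrapper) prefixes from the keys.
            pre_dict = {k.replace("module.backbone.", ""): v for k, v in pre_dict.items()}
            model_dict = self.encoder.state_dict()
            # Keep only parameters that exist in the encoder (drops the projection head).
            pre_dict_update = {k: v for k, v in pre_dict.items() if k in model_dict}
            print("[pre_%d/mod_%d]: %d shared layers" % (len(pre_dict), len(model_dict), len(pre_dict_update)))
            model_dict.update(pre_dict_update)
            self.encoder.load_state_dict(model_dict)
            w_after = self.encoder.state_dict()['conv1.weight'].mean().item()
            print("one-layer before/after: [%.8f, %.8f]" % (w_before, w_after))
        else:
            # TFS: train from scratch with a randomly initialized encoder.
            print("TFS: training from scratch")

    def forward(self, x):
        outs = self.encoder(x)
        decoder_out = self.decoder(outs)
        out = self.seg_head(decoder_out)
        return out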
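The checkpoint-loading logic above keeps every pre-trained key whose name exists in the encoder. If the target task changes a layer's shape (for example, a different number of input channels), load_state_dict raises a size-mismatch error. The sketch below adds a shape check when filtering keys; it works on dictionaries of shapes so that it runs without PyTorch, and the parameter names and shapes in the example are hypothetical. With real state dicts, the comparison would be made on each tensor's .shape.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def remap_and_filter(pretrained: Dict[str, Shape], model: Dict[str, Shape],
                     prefix: str = "module.backbone.") -> Dict[str, Shape]:
    """Strip the checkpoint prefix, then keep keys the encoder has with an identical shape."""
    shared = {}
    for key, shape in pretrained.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        # Skip projection-head weights (absent from the encoder) and shape mismatches.
        if name in model and model[name] == shape:
            shared[name] = shape
    return shared


# Hypothetical parameter names and shapes, for illustration only.
pretrained = {
    "module.backbone.conv1.weight": (64, 1, 7, 7, 7),
    "module.backbone.layer1.0.conv1.weight": (64, 64, 1, 1, 1),
    "module.head.mlp.0.weight": (2048, 2048),
}
encoder = {
    "conv1.weight": (64, 2, 7, 7, 7),  # e.g. a two-channel input: shape differs
    "layer1.0.conv1.weight": (64, 64, 1, 1, 1),
}
shared = remap_and_filter(pretrained, encoder)
print(sorted(shared))
```

With this filter, mismatched layers such as conv1 here keep their random initialization instead of aborting the load.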

Citation

If this code is helpful for your study, please cite:

@inproceedings{DeSD,
  title={DeSD: Self-Supervised Learning with Deep Self-Distillation for 3D Medical Image Segmentation},
  author={Ye, Yiwen and Zhang, Jianpeng and Chen, Ziyang and Xia, Yong},
  booktitle={Medical Image Computing and Computer Assisted Intervention -- MICCAI 2022},
  pages={545--555},
  year={2022}
}

Acknowledgements

Part of the code is reused from DINO. We thank Caron et al. for releasing the DINO code.

Contact

Yiwen Ye (ywye@mail.nwpu.edu.cn)