DeSD-code
This is the official PyTorch implementation of our MICCAI 2022 paper "DeSD: Self-Supervised Learning with Deep Self-Distillation for 3D Medical Image Segmentation". In this paper, we reformulate self-supervised learning (SSL) in a Deep Self-Distillation (DeSD) manner to improve the representation quality of both shallow and deep layers.
The abstract of the paper is available here.
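To illustrate the idea of deep self-distillation, the sketch below assumes a DINO-style objective (cross-entropy between a sharpened teacher distribution and the student distribution) and applies it to student outputs taken at several depths. Each of these outputs is compared with the same teacher output from the final layer. This is a simplified, dependency-free illustration and not the loss implemented in this repository: teacher centering, multi-crop views and projection heads are omitted, and the temperatures (0.04 for the teacher, 0.1 for the student, borrowed from DINO's defaults) are assumptions.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float) -> List[float]:
    """Numerically stable temperature-scaled log-softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def softmax(logits: Sequence[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax (probabilities sum to 1)."""
    return [math.exp(v) for v in log_softmax(logits, temperature)]


def distill_ce(teacher_logits, student_logits, t_teacher=0.04, t_student=0.1):
    """Cross-entropy H(P_teacher, P_student); the low teacher temperature sharpens the target."""
    p_t = softmax(teacher_logits, t_teacher)
    log_p_s = log_softmax(student_logits, t_student)
    return -sum(pt * lps for pt, lps in zip(p_t, log_p_s))


def deep_self_distillation_loss(teacher_logits, student_logits_per_depth):
    """Average the distillation loss over student outputs from several depths.

    Assumption: every student output (shallow or deep) is supervised by the
    same teacher output taken from the teacher's final layer.
    """
    losses = [distill_ce(teacher_logits, s) for s in student_logits_per_depth]
    return sum(losses) / len(losses)


# Usage: one teacher output, student outputs from a shallow and a deep stage.
teacher = [2.0, 0.5, -1.0]
students = [[0.1, 0.0, -0.1], [1.5, 0.3, -0.8]]
print(deep_self_distillation_loss(teacher, students))
```

In this sketch, every depth is pulled toward the same teacher target, so shallow stages receive a direct training signal instead of relying only on the gradient that flows back from the final layer.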
Requirements
- CUDA 10.1
- Python 3.6
- PyTorch 1.7.1
- Torchvision 0.8.2
Usage
Installation
- Clone this repo.
git clone https://github.com/yeerwen/DeSD.git
cd DeSD
Data Preparation
- Download the DeepLesion dataset.
Pre-processing
- Run DL_save_nifti.py (from the downloaded files) to convert the PNG images to NIfTI (.nii.gz) format.
- Run re_spacing_ITK.py to resample the CT volumes.
- Run splitting_to_patches.py to extract about 125k sub-volumes; the pre-processed dataset will be saved in DL_patches_v2/.
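The resampling and patch-extraction steps come down to simple index arithmetic. The sketch below is a hypothetical illustration, not the code in re_spacing_ITK.py or splitting_to_patches.py: the helper names are invented here, and the voxel spacings, target spacing, patch size and stride used in the example are assumptions rather than the values used by this repository.

```python
from itertools import product
from typing import List, Sequence, Tuple


def resampled_shape(shape: Sequence[int], spacing: Sequence[float],
                    target_spacing: Sequence[float]) -> Tuple[int, ...]:
    """Grid size after resampling, preserving the physical extent (voxels * spacing)."""
    return tuple(int(round(n * s / t)) for n, s, t in zip(shape, spacing, target_spacing))


def patch_origins(length: int, patch: int, stride: int) -> List[int]:
    """Start indices of sliding-window patches along one axis.

    The last window is shifted so it ends exactly at the volume border.
    If the axis is shorter than the patch, a single window at 0 is returned
    (the volume would then need padding).
    """
    if length <= patch:
        return [0]
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] != length - patch:
        starts.append(length - patch)
    return starts


def patch_grid(shape: Sequence[int], patch: Sequence[int],
               stride: Sequence[int]) -> List[Tuple[int, ...]]:
    """All 3D patch origins: the Cartesian product of the per-axis origins."""
    per_axis = [patch_origins(n, p, s) for n, p, s in zip(shape, patch, stride)]
    return list(product(*per_axis))


# Usage with hypothetical numbers: a 512x512x60 CT at 0.8x0.8x5.0 mm resampled to 1.0x1.0x2.5 mm.
new_shape = resampled_shape((512, 512, 60), (0.8, 0.8, 5.0), (1.0, 1.0, 2.5))
print(new_shape)
print(len(patch_grid(new_shape, (96, 96, 64), (96, 96, 64))))
```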
Training
- Run sh run_ssl.sh for self-supervised pre-training.
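Part of the code is reused from DINO, and the checkpoint loaded in the fine-tuning example stores its weights under a 'teacher' key, which suggests a teacher-student setup. The sketch below shows a DINO-style momentum (EMA) teacher update with a cosine momentum schedule. It is an assumption-based illustration, not code taken from run_ssl.sh: parameters are modelled as plain lists, and the base momentum of 0.996 is DINO's default, assumed here.

```python
import math
from typing import Dict, List


def ema_update(teacher: Dict[str, List[float]], student: Dict[str, List[float]],
               momentum: float) -> None:
    """In-place update: teacher <- momentum * teacher + (1 - momentum) * student."""
    for name, s_params in student.items():
        teacher[name] = [momentum * t + (1.0 - momentum) * s
                         for t, s in zip(teacher[name], s_params)]


def momentum_at(step: int, total_steps: int, base: float = 0.996, final: float = 1.0) -> float:
    """Cosine schedule that raises the momentum from `base` to `final` over training."""
    return final - (final - base) * (math.cos(math.pi * step / total_steps) + 1) / 2


# Usage: one update step early in training.
teacher = {"conv1.weight": [0.0, 1.0]}
student = {"conv1.weight": [1.0, 1.0]}
ema_update(teacher, student, momentum_at(0, 1000))
print(teacher)
```

Because the teacher only receives a slow moving average of the student weights, it changes smoothly and provides a stable distillation target.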
Pre-trained Model
- The pre-trained model is available at DeSD_Res50.
Fine-tune DeSD on your own target task
For target segmentation tasks, the 3D model can be initialized with the pre-trained encoder, as in the following example:
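import torch
from torch import nn

# build a 3D segmentation model based on resnet50
# Note: `Decoder` is your own decoder module and is not defined here.
class ResNet50_Decoder(nn.Module):
    def __init__(self, Resnet50_encoder, skip_connection, n_class=1, pre_training=True,
                 load_path=None, decoder_out_channels=32):
        super(ResNet50_Decoder, self).__init__()
        self.encoder = Resnet50_encoder
        self.decoder = Decoder(skip_connection)
        # 1x1x1 convolution mapping decoder features to per-class logits.
        # nn.Conv3d needs (in_channels, out_channels, kernel_size); decoder_out_channels
        # is an added argument with a placeholder default: set it to your Decoder's output width.
        self.seg_head = nn.Conv3d(decoder_out_channels, n_class, kernel_size=1)

        if pre_training:
            print('loading from checkpoint ssl: {}'.format(load_path))
            # Record one weight statistic to check that loading actually changed the encoder.
            w_before = self.encoder.state_dict()['conv1.weight'].mean().item()
            # Use the teacher weights stored in the SSL checkpoint.
            pre_dict = torch.load(load_path, map_location='cpu')['teacher']
            # Strip the "module.backbone." prefix so keys match the bare encoder.
            pre_dict = {k.replace("module.backbone.", ""): v for k, v in pre_dict.items()}
            model_dict = self.encoder.state_dict()
            # Keep only parameters that also exist in the encoder.
            pre_dict_update = {k: v for k, v in pre_dict.items() if k in model_dict}
            print("[pre_%d/mod_%d]: %d shared layers" % (len(pre_dict), len(model_dict), len(pre_dict_update)))
            model_dict.update(pre_dict_update)
            self.encoder.load_state_dict(model_dict)
            w_after = self.encoder.state_dict()['conv1.weight'].mean().item()
            print("one-layer before/after: [%.8f, %.8f]" % (w_before, w_after))
        else:
            print("Training from scratch (TFS)")

    def forward(self, input):
        outs = self.encoder(input)        # encoder features (consumed by the decoder)
        decoder_out = self.decoder(outs)  # decoder feature map
        out = self.seg_head(decoder_out)  # per-voxel class logits
        return out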
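The weight-loading step above strips a key prefix and keeps only the parameters that the encoder also has. The dependency-free sketch below isolates that matching logic so it can be checked without a checkpoint. The function name match_pretrained and the optional shape_of check are hypothetical additions, not part of this repository. With PyTorch tensors you could pass shape_of=lambda t: tuple(t.shape) to also skip parameters whose shapes differ, which load_state_dict would otherwise reject.

```python
from typing import Any, Callable, Dict, Optional


def match_pretrained(pre_dict: Dict[str, Any], model_dict: Dict[str, Any],
                     prefix: str = "module.backbone.",
                     shape_of: Optional[Callable[[Any], Any]] = None) -> Dict[str, Any]:
    """Return the pre-trained entries that can be loaded into model_dict.

    1. Remove `prefix` from every key (same replace as in the example above).
    2. Keep a key only if the model has it.
    3. If `shape_of` is given, also require matching shapes.
    """
    stripped = {k.replace(prefix, ""): v for k, v in pre_dict.items()}
    matched = {}
    for k, v in stripped.items():
        if k not in model_dict:
            continue  # e.g. SSL-only parameters with no counterpart in the encoder
        if shape_of is not None and shape_of(v) != shape_of(model_dict[k]):
            continue  # same name but incompatible shape
        matched[k] = v
    return matched


# Usage with lists standing in for tensors and len() standing in for .shape.
pre = {
    "module.backbone.conv1.weight": [1, 2, 3],
    "module.backbone.layer1.0.conv1.weight": [4, 5],
    "module.head.mlp.weight": [6],
}
model = {"conv1.weight": [0, 0, 0], "layer1.0.conv1.weight": [0, 0, 0]}
print(sorted(match_pretrained(pre, model)))
print(sorted(match_pretrained(pre, model, shape_of=len)))
```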
Citation
If this code is helpful for your study, please cite:
@inproceedings{DeSD,
  title={DeSD: Self-Supervised Learning with Deep Self-Distillation for 3D Medical Image Segmentation},
  author={Ye, Yiwen and Zhang, Jianpeng and Chen, Ziyang and Xia, Yong},
  booktitle={Medical Image Computing and Computer Assisted Intervention -- MICCAI 2022},
  pages={545--555},
  year={2022}
}
Acknowledgements
Part of the code is reused from DINO. We thank Caron et al. for releasing the DINO code.
Contact
Yiwen Ye (ywye@mail.nwpu.edu.cn)