
Self-supervised federated learning for medical imaging


Self-supervised Federated Learning (SSL-FL)

Label-Efficient Self-Supervised Federated Learning for Tackling Data Heterogeneity in Medical Imaging


TL;DR: PyTorch implementation of the self-supervised federated learning framework proposed in our paper, for simulating self-supervised classification on multi-institutional medical imaging data with federated learning.

  • Our framework employs masked image encoding as the self-supervised task to learn efficient representations from images.
  • Extensive experiments are performed on diverse medical datasets, including retinal images, dermatology images, and chest X-rays.
  • In particular, we implement BEiT and MAE as the self-supervised learning modules.
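To make the masked-image self-supervision concrete, here is a minimal NumPy sketch of MAE-style random masking: each image is split into 16x16 patches, the patches are shuffled per sample, and only a fraction is kept as visible input to the encoder. The helper names `patchify` and `random_masking` are hypothetical illustrations, not functions from this repository, and the 0.75 mask ratio is the default reported for MAE rather than a value taken from SSL-FL. BEiT uses block-wise masking and predicts discrete visual tokens instead, which this sketch does not cover.

```python
import numpy as np

def patchify(imgs, patch=16):
    """Split (N, H, W, C) images into (N, L, patch*patch*C) flattened patches."""
    n, h, w, c = imgs.shape
    gh, gw = h // patch, w // patch
    x = imgs.reshape(n, gh, patch, gw, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (N, gh, gw, patch, patch, C)
    return x.reshape(n, gh * gw, patch * patch * c)

def random_masking(tokens, mask_ratio=0.75, rng=None):
    """MAE-style per-sample random masking (illustrative sketch, not SSL-FL code).

    Returns the visible tokens, a mask in original patch order (1 = masked),
    and the indices that undo the per-sample shuffle.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, length, _ = tokens.shape
    len_keep = int(length * (1 - mask_ratio))
    noise = rng.random((n, length))              # one random score per patch
    ids_shuffle = np.argsort(noise, axis=1)      # patches with the lowest scores are kept
    ids_restore = np.argsort(ids_shuffle, axis=1)
    ids_keep = ids_shuffle[:, :len_keep]
    visible = np.take_along_axis(tokens, ids_keep[:, :, None], axis=1)
    mask = np.ones((n, length))
    mask[:, :len_keep] = 0                       # 0 = visible, 1 = masked (shuffled order)
    mask = np.take_along_axis(mask, ids_restore, axis=1)  # back to original patch order
    return visible, mask, ids_restore

imgs = np.random.default_rng(0).random((2, 224, 224, 3))
tokens = patchify(imgs)                          # (2, 196, 768)
visible, mask, ids_restore = random_masking(tokens, 0.75, np.random.default_rng(1))
print(visible.shape, mask.sum(axis=1))           # → (2, 49, 768) [147. 147.]
```

Returning `ids_restore` lets a decoder put mask tokens back in their original positions, and drawing fresh noise per sample gives each image a different mask.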

Reference

If you find our work helpful in your research, or if you use any of our source code or datasets, please cite our paper. The BibTeX entry is listed below:

@article{yan2023label,
  title={Label-efficient self-supervised federated learning for tackling data heterogeneity in medical imaging},
  author={Yan, Rui and Qu, Liangqiong and Wei, Qingyue and Huang, Shih-Cheng and Shen, Liyue and Rubin, Daniel and Xing, Lei and Zhou, Yuyin},
  journal={IEEE Transactions on Medical Imaging},
  year={2023},
  publisher={IEEE}
}

Prerequisites:

Set Up Environment

  • conda env create -f environment.yml
  • NVIDIA GPU (tested on 4x NVIDIA Tesla V100 32GB and 8x NVIDIA GeForce RTX 2080 Ti on local workstations)
  • Python (3.8.12), torch (1.7.1), timm (0.3.2), numpy (1.21.2), pandas (1.4.2), scikit-learn (1.0.2), scipy (1.7.1), seaborn (0.11.2)
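If the conda environment cannot be reproduced exactly, one quick way to see how far an existing environment drifts from the tested versions above is to compare installed distribution versions. The sketch below uses the standard-library `importlib.metadata` (Python 3.8+); the `check_versions` helper is hypothetical and not part of the repository.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions listed above as tested, keyed by pip distribution name.
TESTED_VERSIONS = {
    "torch": "1.7.1",
    "timm": "0.3.2",
    "numpy": "1.21.2",
    "pandas": "1.4.2",
    "scikit-learn": "1.0.2",
    "scipy": "1.7.1",
    "seaborn": "0.11.2",
}

def check_versions(tested=TESTED_VERSIONS):
    """Return {package: (tested_version, installed_version_or_None)} for mismatches."""
    mismatches = {}
    for pkg, wanted in tested.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None
        # Drop local build tags such as "+cu110" before comparing.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[pkg] = (wanted, installed)
    return mismatches

if __name__ == "__main__":
    print("Python", ".".join(map(str, sys.version_info[:3])), "(tested: 3.8.12)")
    for pkg, (wanted, installed) in check_versions().items():
        print(f"{pkg}: tested {wanted}, found {installed}")
```

Stripping the local tag means a CUDA build such as `1.7.1+cu110` counts as a match for torch 1.7.1; only the release number has to agree.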

Data Preparation

Please refer to SSL-FL/data for information on the directory structures of data folders, download links to datasets, and instructions on how to train on custom datasets.

Self-supervised Federated Learning for Medical Image Classification

In this paper, we selected ViT-B/16 as the backbone for all methods. The specifications for BEiT-B are as follows: #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16 (#parameters: 86M).
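The 86M figure can be sanity-checked by counting the weights of a standard ViT-B/16 encoder. The sketch below assumes 224x224 RGB inputs, a learned absolute position embedding, and no classification head; it ignores BEiT-specific additions such as relative position bias, so it approximates rather than reproduces the exact BEiT-B count. The number of heads does not appear because splitting the 768-dim projections across 12 heads does not change their size. `vit_param_count` is a hypothetical helper.

```python
def vit_param_count(depth=12, hidden=768, ffn_factor=4, patch=16,
                    img_size=224, in_chans=3, num_classes=0):
    """Count parameters (weights and biases) of a standard ViT encoder."""
    num_patches = (img_size // patch) ** 2
    ffn = hidden * ffn_factor
    patch_embed = patch * patch * in_chans * hidden + hidden  # conv with kernel = stride = patch
    cls_token = hidden
    pos_embed = (num_patches + 1) * hidden                    # +1 for the [CLS] token
    per_block = (
        2 * hidden                            # LayerNorm before attention
        + hidden * 3 * hidden + 3 * hidden    # fused q/k/v projection
        + hidden * hidden + hidden            # attention output projection
        + 2 * hidden                          # LayerNorm before the MLP
        + hidden * ffn + ffn                  # MLP up-projection (FFN factor 4x)
        + ffn * hidden + hidden               # MLP down-projection
    )
    final_norm = 2 * hidden
    head = hidden * num_classes + num_classes if num_classes else 0
    return patch_embed + cls_token + pos_embed + depth * per_block + final_norm + head

print(f"{vit_param_count() / 1e6:.1f}M")  # → 85.8M
```

Adding a 1000-class linear head (`num_classes=1000`) brings the total to 86,567,656, so both variants round to the 86M quoted above.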

Please refer to SSL-FL/data for links to the pre-trained checkpoints used to generate the results.

Self-supervised Federated Pre-training and Fine-tuning

Sample scripts for running Fed-BEiT and Fed-MAE pre-training and fine-tuning on the Retina dataset can be found in SSL-FL/code/fed_beit/script/retina (Fed-BEiT) and SSL-FL/code/fed_mae/script/retina (Fed-MAE).
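Federated pre-training and fine-tuning alternate local training on each simulated institution with aggregation on a server. As an illustration only, the sketch below assumes sample-size-weighted federated averaging (FedAvg) over parameter dictionaries; the aggregation rule and communication schedule actually used by Fed-BEiT and Fed-MAE should be checked in the scripts listed above. `fedavg` is a hypothetical helper, and NumPy arrays stand in for PyTorch tensors.

```python
import numpy as np

def fedavg(client_states, client_sizes):
    """Average client parameters, weighting each client by its sample count.

    client_states: list of dicts mapping parameter name -> np.ndarray
    client_sizes:  number of local training samples on each client
    """
    # Assumption: FedAvg-style weighting; not necessarily the rule used in SSL-FL.
    total = float(sum(client_sizes))
    weights = [n / total for n in client_sizes]
    return {
        name: sum(w * state[name] for w, state in zip(weights, client_states))
        for name in client_states[0]
    }

# Two hypothetical clients holding 100 and 300 samples (weights 0.25 and 0.75).
clients = [{"w": np.zeros(3)}, {"w": np.full(3, 4.0)}]
global_state = fedavg(clients, [100, 300])
print(global_state["w"])  # → [3. 3. 3.]
```

Weighting by local sample count lets larger institutions pull the global model more strongly, which is one reason heterogeneous client data matters in this setting.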

To run Fed-BEiT, please download the DALL-E tokenizer weights and save encoder.pkl and decoder.pkl to SSL-FL/data/tokenizer_weight:

  • wget -P SSL-FL/data/tokenizer_weight https://cdn.openai.com/dall-e/encoder.pkl
  • wget -P SSL-FL/data/tokenizer_weight https://cdn.openai.com/dall-e/decoder.pkl

Acknowledgements