MI-Seg is a framework, based on the MONAI library, for cross-modality clinical image segmentation using Conditional Models and Interleaved Training.
Our paper has been accepted at ICCVW 2023 and is available here and on arXiv. Please cite our work with:
@InProceedings{Bastico_2023_ICCV,
    author    = {Bastico, Matteo and Ryckelynck, David and Cort\'e, Laurent and Tillier, Yannick and Decenci\`ere, Etienne},
    title     = {A Simple and Robust Framework for Cross-Modality Medical Image Segmentation Applied to Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
    month     = {October},
    year      = {2023},
    pages     = {4128-4138}
}
Our released implementation is tested on:
- Ubuntu 22.04
- Python 3.10.8
- PyTorch 1.13.1 and PyTorch Lightning 1.8.6
- Ray 2.2.0
- NVIDIA CUDA 11.7
- MONAI 1.1.0
- Optuna 3.1.0
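As a quick sanity check of your environment, here is a minimal sketch (it only assumes the packages listed above are importable):

# Print installed versions to compare against the tested configuration above.
import torch
import pytorch_lightning as pl
import monai
import ray
import optuna

print("PyTorch:", torch.__version__, "| CUDA:", torch.version.cuda,
      "| GPU available:", torch.cuda.is_available())
print("Lightning:", pl.__version__, "| MONAI:", monai.__version__,
      "| Ray:", ray.__version__, "| Optuna:", optuna.__version__)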
- Clone our project folder
- Create and activate your conda environment with
conda create -n MI-Seg python=3.10.8
conda activate MI-Seg
- Install dependencies
Note: for the PyTorch installation with CUDA support, follow https://pytorch.org/get-started/locally/.
pip install -r requirements.txt
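For example, for PyTorch 1.13.1 with CUDA 11.7 the official instructions point to a command of this form (check the link above for the one matching your system):

pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117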
The dataset used in our experiments can be downloaded here upon access request. Download and unzip it into the /dataset/MM-WHS folder.
[Optional] Convert the labels and perform N4 bias correction of the MRIs using the provided notebook load_data.ipynb.
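For reference, a minimal sketch of the N4 step, assuming SimpleITK and placeholder file names (the notebook's exact parameters may differ):

import SimpleITK as sitk

# N4 bias field correction of a single MRI volume (placeholder file names).
image = sitk.ReadImage("mr_train_1001_image.nii.gz", sitk.sitkFloat32)
mask = sitk.OtsuThreshold(image, 0, 1, 200)  # rough foreground mask
corrected = sitk.N4BiasFieldCorrection(image, mask)
sitk.WriteImage(corrected, "mr_train_1001_image_n4.nii.gz")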
You should end up with a data structure similar to the following (sub-folders are not shown):
MM-WHS
├── ct_train # CT training folder
│ ├── ct_train_1001_image.nii.gz # Image
│ ├── ct_train_1001_label.nii.gz # Label
│ ...
├── ct_test
├── mr_train
├── mr_test
...
The splits we used for our cross-validation are provided in CT_fold1.json and CT_fold2.json.
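To quickly inspect a split file, a minimal sketch (the key names inside the JSON are not documented here, so the output depends on its layout):

import json

# List the split keys (e.g., training/validation) defined by a fold file.
# Assumes the file is a JSON object keyed by split name (an assumption).
with open("CT_fold1.json") as f:
    fold = json.load(f)
print(list(fold.keys()))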
To train a model you can use the provided train.py script. Single trainings are based on PyTorch Lightning, and all the Trainer arguments can be passed to the script (see here). Additionally, we provide model-, data- and logger-specific arguments. For a full list of the possible arguments, execute python train.py --help.
An example of C-Swin-UNETR training on a single GPU is shown in the following:
python train.py --model_name=swin_unetr --out_channels=6 --feature_size=48 --num_heads=3 --accelerator=gpu --devices=1 --max_epochs=2500 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --lr=1e-4 --batch_size=1 --patches_training_sample=1
The available models are unet, unetr, swin_unetr and pre_swin_unetr (in the latter case, the pre-trained MONAI model must be provided through --pre_swin).
Furthermore, we use WandB to log the experiments, and its settings can be passed as arguments. In the previous example WandB runs in online mode, so you need to provide your login and API key. To change the WandB mode, set wandb_mode=offline.
Note: AMP should be disabled (--no_amp) when activation checkpointing (--use_checkpoint) is enabled to save memory during training of Swin-UNETR-based models.
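For example, a variant of the training command above with checkpointing enabled and AMP disabled would be:

python train.py --model_name=swin_unetr --out_channels=6 --feature_size=48 --num_heads=3 --accelerator=gpu --devices=1 --max_epochs=2500 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --lr=1e-4 --batch_size=1 --patches_training_sample=1 --use_checkpoint --no_amp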
Our pre-trained models can be downloaded here and tested with the test.py script. The path of the model weights should be provided as --checkpoint (note that the model weights must be stored under the state_dict key).
Example:
python test.py --out_channels=6 --model_name=swin_unetr --num_workers=2 --feature_size=48 --num_heads=3 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --checkpoint=experiments/<path>
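Concretely, a minimal sketch of what the state_dict requirement means when inspecting a checkpoint (the path is a placeholder):

import torch

# test.py reads the weights from the "state_dict" entry of the checkpoint.
ckpt = torch.load("experiments/my_model.pt", map_location="cpu")  # placeholder path
weights = ckpt["state_dict"]  # a KeyError here means the checkpoint was saved differently
print(len(weights), "tensors in the state_dict")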
Hyper-parameter optimization is based on Optuna. For the moment, the script supports automatic setup of distributed tuning ONLY on Slurm environments; it therefore needs to be adapted by the user to run in other multi-GPU environments.
The hyper-parameter grid is set automatically for each model, as stated in our paper, and the tuning can be started as in the following.
The script will run 10 trials with the TPE optimizer and ASHA pruner, and save the results in the MI-Seg.log log file (if on Slurm) or in MI-Seg.sqlite (if not).
python -u tune.py --num_workers=2 --out_channels=6 --no_include_background --criterion=generalized_dice_focal --scheduler=warmup_cosine --model_name=swin_unetr --n_trials=10 --study_name=c-swin-unetr --max_epochs=2500 --check_val_every_n_epoch=50 --batch_size=1 --patches_training_sample=4 --iters_to_accumulate=4 --cycles=0.5 --storage_name=MI-Seg --min_lr=1e-5 --max_lr=1e-3 --vit_norm_name=instance_cond --encoder_norm_name=instance_cond --port=23456
The script can be run multiple times with the same --storage_name in order to continue a previous tuning.
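Once trials are stored, a hedged sketch for inspecting the study programmatically, assuming the non-Slurm SQLite storage (for the Slurm log file, Optuna's JournalStorage would be used instead):

import optuna

# Load the study written by tune.py into the local SQLite storage.
study = optuna.load_study(study_name="c-swin-unetr", storage="sqlite:///MI-Seg.sqlite")
print("Best value:", study.best_value)
print("Best params:", study.best_trial.params)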
To open the dashboards of log files not stored as RDB, we provide the utils/run_server.py --path=<storage> script. The dashboard of the tuning presented in our paper is available at experiments/optuna/MI-Seg.log and can be opened with
python utils/run_server.py --path=experiments/optuna/MI-Seg.log
The best pre-trained model weights for Conditional UNet and Swin-UNETR resulting from our hyper-parameter optimization can be downloaded here.
For instance, to produce the segmentations on the test dataset using the provided weights, you can run, for Conditional UNet:
python predict_whs.py --model=unet_vanilla --encoder_norm_name=instance_cond --feature_size 16 64 128 256 512 --num_res_units=3 --strides 1 2 2 2 1 --out_channels=8 --checkpoint=path/to/weights.pt --result_dir=path/to/result
or for Conditional Swin-UNETR:
python -u predict_whs.py --model=swin_unetr --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --feature_size=36 --num_heads=4 --out_channels=8 --checkpoint=path/to/weights.pt --result_dir=path/to/result
- Implement LN for the convolutional layers of MONAI (for testing purposes)
- Implement distributed tuning on non-Slurm environments
See the open issues for a full list of proposed features (and known issues).
If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again!
- Fork the Project
- Create your Feature Branch (git checkout -b feature/my_feature)
- Commit your Changes (git commit -m 'Add my_feature')
- Push to the Branch (git push origin feature/my_feature)
- Open a Pull Request
Distributed under the MIT License. See LICENSE.txt for more information.
Matteo Bastico - @matteobastico - matteo.bastico@minesparis.psl.eu
Project Link: https://github.com/matteo-bastico/MI-Seg
This work was supported by the H2020 European Project ...