This repo is the official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer", which has been accepted at AAAI 2022.
We propose a Channel Transformer module (CTrans) and use it to replace the skip connections in the original U-Net; hence the name "U-CTrans-Net".
Install the dependencies listed in requirements.txt using:
pip install -r requirements.txt
The original data can be downloaded from the following links:
- MoNuSeg Dataset - Link (Original)
- GlaS Dataset - Link (Original)
Then organize the datasets into the following directory structure so that the code can use them directly:
├── datasets
│   ├── GlaS
│   │   ├── Test_Folder
│   │   │   ├── img
│   │   │   └── labelcol
│   │   ├── Train_Folder
│   │   │   ├── img
│   │   │   └── labelcol
│   │   └── Val_Folder
│   │       ├── img
│   │       └── labelcol
│   └── MoNuSeg
│       ├── Test_Folder
│       │   ├── img
│       │   └── labelcol
│       ├── Train_Folder
│       │   ├── img
│       │   └── labelcol
│       └── Val_Folder
│           ├── img
│           └── labelcol
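Before training, it can help to confirm that the folders match this layout. The helper below is an illustrative sketch, not part of the repository; it assumes the datasets directory sits in the current working directory, and it only checks that every split contains non-empty img and labelcol folders. It does not verify that images and labels are correctly paired.

```python
from pathlib import Path
from typing import List

# Split and sub-folder names taken from the directory layout above.
SPLITS = ("Train_Folder", "Val_Folder", "Test_Folder")
SUBDIRS = ("img", "labelcol")


def check_dataset_layout(root: str, dataset: str) -> List[str]:
    """Return a list of problems found under <root>/<dataset> (empty if OK)."""
    problems = []
    base = Path(root) / dataset
    for split in SPLITS:
        for sub in SUBDIRS:
            folder = base / split / sub
            if not folder.is_dir():
                problems.append(f"missing folder: {folder}")
            elif not any(folder.iterdir()):
                problems.append(f"empty folder: {folder}")
    return problems


if __name__ == "__main__":
    # Assumption: run from the repository root, next to the datasets folder.
    for name in ("GlaS", "MoNuSeg"):
        issues = check_dataset_layout("datasets", name)
        status = "OK" if not issues else f"{len(issues)} problem(s)"
        print(f"{name}: {status}")
        for issue in issues:
            print("   ", issue)
```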
The Synapse dataset we used is provided by TransUNet's authors. Please go to https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md for details.
As mentioned in the paper, we introduce two strategies to optimize UCTransNet.
The first step is to change the settings in Config.py, which contains all the configurations, including the learning rate, batch size, etc.
We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:
python train_model.py
Because our method only replaces the skip connections in U-Net, the U-Net parameters can be reused as part of the pretrained weights. By first training a classic U-Net using /nets/UNet.py and then using its pretrained weights to train UCTransNet, the CTrans module receives better initial features.
This strategy can speed up convergence and, in some cases, may improve the final segmentation performance.
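A minimal sketch of how U-Net weights can be carried over: keep only the pretrained entries whose parameter names and shapes also exist in the new model, and leave everything else (such as the CTrans parameters) at its fresh initialization. The helper name select_matching_weights is hypothetical, and the sketch assumes the encoder and decoder layers share parameter names between /nets/UNet.py and UCTransNet; if they do not, a name-mapping step would be needed first.

```python
from typing import Any, Dict, List, Tuple


def select_matching_weights(
    pretrained: Dict[str, Any], target: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Keep pretrained entries whose name and shape match the target model.

    Both arguments are state-dict-like mappings (parameter name -> tensor);
    only the .shape attribute of each value is inspected, so no specific
    framework is required. Returns the kept weights and the skipped names.
    """
    kept, skipped = {}, []
    for name, value in pretrained.items():
        if name in target and tuple(target[name].shape) == tuple(value.shape):
            kept[name] = value
        else:
            # Name absent from the new model, or shape changed (e.g. output head).
            skipped.append(name)
    return kept, skipped
```

With PyTorch, the target mapping would come from model.state_dict(), the pretrained one from torch.load(path_to_unet_checkpoint, map_location="cpu"), and the result would be applied with model.load_state_dict(kept, strict=False) so that parameters without a pretrained counterpart are left untouched. If the checkpoint stores the weights under a key such as 'state_dict' (an assumption about the save format), unwrap it before filtering.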
We provide pretrained weights for GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links:
- GlaS: https://drive.google.com/file/d/1ciAwb2-0G1pZrt_lgSwd-7vH1STmxdYe/view?usp=sharing
- MoNuSeg: https://drive.google.com/file/d/1CJvHoh3VrPsBn_njZDo6SvJF_yAVe5MK/view?usp=sharing
First, set the session name in Config.py to the one used in the training phase.
Then run:
python test_model.py
This reports the Dice and IoU scores and produces the visualization results.
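For reference, the two metrics can be computed on a pair of binary masks as below. This is a plain-Python sketch of the standard definitions, Dice = 2|P∩G| / (|P| + |G|) and IoU = |P∩G| / |P∪G|; the repository's test script may differ in details such as the binarization threshold, the smoothing constant, or whether scores are averaged per image or over the whole test set.

```python
from typing import Sequence, Tuple


def dice_and_iou(
    pred: Sequence[int], target: Sequence[int], eps: float = 1e-8
) -> Tuple[float, float]:
    """Dice and IoU for two flattened binary masks (values 0 or 1).

    eps avoids division by zero; two empty masks score 1.0 on both metrics.
    """
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    p_sum = sum(1 for p in pred if p)
    t_sum = sum(1 for t in target if t)
    union = p_sum + t_sum - inter
    dice = (2 * inter + eps) / (p_sum + t_sum + eps)
    iou = (inter + eps) / (union + eps)
    return dice, iou
```

For a single pair of masks the two scores are linked by Dice = 2 * IoU / (1 + IoU), so they rank individual predictions identically; averages over many images need not satisfy that relation exactly.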
In our code, we carefully set the random seed and put cuDNN in 'deterministic' mode to eliminate randomness. However, some factors may still cause different training results, e.g., the CUDA version, the GPU type, and the number of GPUs. Our experiments were run on an NVIDIA A40 (48 GB) GPU with CUDA 11.2.
In multi-GPU settings in particular, the upsampling operation is a major source of randomness. See https://pytorch.org/docs/stable/notes/randomness.html for more details.
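The seeding described above typically looks something like the following. This is an illustrative helper (the name seed_everything is hypothetical, not taken from the repository) that seeds Python's random module, NumPy and PyTorch when they are installed, and switches cuDNN to deterministic, non-benchmarking mode.

```python
import random


def seed_everything(seed: int, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs, skipping libraries not installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)           # CPU generator (and CUDA in recent releases)
    torch.cuda.manual_seed_all(seed)  # every visible GPU; a no-op without CUDA
    if deterministic:
        # Use deterministic cuDNN kernels and disable autotuning, which could
        # otherwise pick different convolution algorithms from run to run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```

These settings do not cover every source of nondeterminism. Calling torch.use_deterministic_algorithms(True) makes PyTorch raise an error whenever an operation without a deterministic implementation is used, which can help locate culprits such as the CUDA backward pass of interpolation-based upsampling; the exact list of affected operations depends on the PyTorch version. On CUDA 10.2 and later this mode also requires the environment variable CUBLAS_WORKSPACE_CONFIG to be set (for example to :4096:8), as described in the randomness notes linked above.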
When training, we suggest training the model twice to verify whether the randomness has been eliminated. Because we use an early stopping strategy, the final performance may change significantly due to randomness.
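One way to carry out this check is to record a per-epoch validation metric (for example the validation Dice) from each of the two runs and find the first epoch at which they disagree. The sketch below assumes the two curves have already been extracted, e.g. from the training logs, whose format is not described here; the function name first_divergence is hypothetical.

```python
from typing import Optional, Sequence


def first_divergence(
    run_a: Sequence[float], run_b: Sequence[float], tol: float = 1e-6
) -> Optional[int]:
    """Index of the first epoch where two runs' metric curves differ.

    Returns None if both runs have the same length and agree within tol at
    every epoch. If they agree but one run is shorter (early stopping fired
    at a different epoch), the length of the shorter run is returned.
    """
    for epoch, (a, b) in enumerate(zip(run_a, run_b)):
        if abs(a - b) > tol:
            return epoch
    if len(run_a) != len(run_b):
        return min(len(run_a), len(run_b))
    return None
```

If it returns None, the two runs matched; otherwise the returned epoch is a reasonable place to start looking for the nondeterministic operation.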
If this code is helpful for your study, please cite:
@misc{wang2021uctransnet,
title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer},
author={Haonan Wang and Peng Cao and Jiaqi Wang and Osmar R. Zaiane},
year={2021},
eprint={2109.04335},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Haonan Wang (haonan1wang@gmail.com)