This repo is the official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer".
We propose a Channel Transformer module (CTrans) and use it to replace the skip connections in the original U-Net; hence the name "U-CTrans-Net".
Install the dependencies from requirements.txt using:
pip install -r requirements.txt
The original data can be downloaded from the following links:
- MoNuSeg Dataset - Link (Original)
- GlaS Dataset - Link (Original)
Then organize the datasets in the following directory structure so that the code can find them:
├── datasets
│   ├── GlaS
│   │   ├── Test_Folder
│   │   │   ├── img
│   │   │   └── labelcol
│   │   ├── Train_Folder
│   │   │   ├── img
│   │   │   └── labelcol
│   │   └── Val_Folder
│   │       ├── img
│   │       └── labelcol
│   └── MoNuSeg
│       ├── Test_Folder
│       │   ├── img
│       │   └── labelcol
│       ├── Train_Folder
│       │   ├── img
│       │   └── labelcol
│       └── Val_Folder
│           ├── img
│           └── labelcol
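Before training, it can help to confirm that the folders match the layout above. The following is a minimal sketch using only the standard library; the helper name `missing_dirs` is hypothetical and not part of this repo, and it only checks that the directories exist, not that the images and labels inside them are paired correctly.

```python
from pathlib import Path

# Split and subfolder names taken from the directory layout above.
SPLITS = ("Train_Folder", "Val_Folder", "Test_Folder")
SUBDIRS = ("img", "labelcol")


def missing_dirs(root, task):
    """Return the expected folders under root/task that do not exist.

    `root` is the `datasets` directory and `task` is "GlaS" or "MoNuSeg".
    (Hypothetical helper for illustration; not part of the repo.)
    """
    base = Path(root) / task
    return [
        str(base / split / sub)
        for split in SPLITS
        for sub in SUBDIRS
        if not (base / split / sub).is_dir()
    ]


if __name__ == "__main__":
    for task in ("GlaS", "MoNuSeg"):
        missing = missing_dirs("datasets", task)
        print(task, "OK" if not missing else f"missing: {missing}")
```

Run from the directory that contains `datasets`; any folder listed as missing needs to be created and filled before training.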
The Synapse dataset we used is provided by TransUNet's authors. Please go to https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md for details.
As mentioned in the paper, we introduce two strategies to optimize UCTransNet.
The first step is to change the settings in Config.py, which contains all the configurations, including the learning rate, batch size, etc.
We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:
python train_model.py
Since our method only replaces the skip connections in U-Net, the U-Net parameters can be reused as part of the pretrained weights.
By first training a classical U-Net using /nets/UNet.py and then using its weights to initialize UCTransNet, the CTrans module receives better initial features.
This strategy can speed up convergence and may improve the final segmentation performance in some cases.
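The weight-reuse step described above can be sketched as filtering the U-Net checkpoint down to the entries that also exist in UCTransNet with the same shape. This is a minimal sketch under stated assumptions: it assumes both models are PyTorch modules whose `state_dict()` keys coincide for the shared U-Net layers (the actual key names in this repo may differ and could need renaming), and the helper `select_transferable` is hypothetical, not a function provided by the repo.

```python
from typing import Any, Dict, Mapping


def select_transferable(pretrained: Mapping[str, Any],
                        target: Mapping[str, Any]) -> Dict[str, Any]:
    """Keep only pretrained entries whose name and shape match the target.

    Hypothetical helper. Values only need a `.shape` attribute, so PyTorch
    state_dict tensors work. Parameters that exist only in the target
    (e.g. the CTrans module) are not returned, so they keep their fresh
    initialization when the result is loaded non-strictly.
    """
    return {
        name: value
        for name, value in pretrained.items()
        if name in target and tuple(target[name].shape) == tuple(value.shape)
    }
```

With PyTorch, the filtered dict could then be passed to `model.load_state_dict(filtered, strict=False)`, which loads the matching U-Net weights and leaves the remaining parameters untouched. If the saved checkpoint wraps the weights in a larger dict (for example under a `state_dict` key), unwrap it before filtering.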
We provide pre-trained weights for GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links:
- GlaS: https://drive.google.com/file/d/1ciAwb2-0G1pZrt_lgSwd-7vH1STmxdYe/view?usp=sharing
- MoNuSeg: https://drive.google.com/file/d/1CJvHoh3VrPsBn_njZDo6SvJF_yAVe5MK/view?usp=sharing
First, set the session name in Config.py to the one used in the training phase.
Then run:
python test_model.py
This reports the Dice and IoU scores and produces the visualization results.
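For reference, the two reported metrics for a predicted binary mask A and a ground-truth mask B are Dice = 2|A∩B| / (|A| + |B|) and IoU = |A∩B| / |A∪B|. The sketch below computes them for flattened 0/1 masks in plain Python. It illustrates the standard definitions only and is not the code in test_model.py, which may differ in details such as smoothing terms, thresholding, or per-image averaging; the function name `dice_and_iou` is hypothetical.

```python
from typing import Sequence, Tuple


def dice_and_iou(pred: Sequence[int], target: Sequence[int]) -> Tuple[float, float]:
    """Dice and IoU for two flattened binary (0/1) masks of equal length."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for p, t in zip(pred, target) if p and t)  # |A ∩ B|
    pred_sum = sum(1 for p in pred if p)                     # |A|
    target_sum = sum(1 for t in target if t)                 # |B|
    union = pred_sum + target_sum - inter                    # |A ∪ B|
    if union == 0:
        # Both masks are empty: treat as a perfect match.
        return 1.0, 1.0
    dice = 2 * inter / (pred_sum + target_sum)
    iou = inter / union
    return dice, iou


print(dice_and_iou([1, 1, 0, 0], [1, 0, 1, 0]))  # → (0.5, 0.3333333333333333)
```

For a single mask pair the two scores are related by Dice = 2·IoU / (1 + IoU), so they rank individual predictions the same way; averages over a dataset, however, can differ.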
If this code is helpful for your study, please cite:
@misc{wang2021uctransnet,
title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer},
author={Haonan Wang and Peng Cao and Jiaqi Wang and Osmar R. Zaiane},
year={2021},
eprint={2109.04335},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Haonan Wang (haonan1wang@gmail.com)