
[ICCV2023] XNet: Wavelet-Based Low and High Frequency Merging Networks for Semi- and Supervised Semantic Segmentation of Biomedical Images



This is the official code for XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images (ICCV 2023).

Overview


Architecture of XNet.


Visualization of the dual-branch inputs. (a) Raw image. (b) Wavelet transform results. (c) Low-frequency image. (d) High-frequency image.
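The low- and high-frequency inputs come from a 2D discrete wavelet transform. The sketch below is an illustration only, assuming PyWavelets (pinned in the requirements) and a single-channel image. It treats the approximation subband as low frequency and the three detail subbands as high frequency. Rebuilding each part at full resolution with the inverse DWT and min-max scaling to 8 bits are our own assumptions; tools/wavelet2D.py may use a different wavelet, normalization, or resizing. The names split_frequencies and minmax_to_uint8 are hypothetical.

```python
import numpy as np
import pywt


def minmax_to_uint8(x):
    """Rescale an array linearly to [0, 255] and cast to uint8 (constant input maps to 0)."""
    x = np.asarray(x, dtype=np.float64)
    span = x.max() - x.min()
    if span == 0:
        return np.zeros(x.shape, dtype=np.uint8)
    return ((x - x.min()) / span * 255.0).astype(np.uint8)


def split_frequencies(image, wavelet="haar"):
    """Split a 2D grayscale image into low- and high-frequency images of the same size.

    Hypothetical helper; the wavelet choice and normalization are assumptions.
    """
    image = np.asarray(image, dtype=np.float64)
    h, w = image.shape
    # Single-level 2D DWT: one approximation subband plus three detail subbands.
    approx, (detail_h, detail_v, detail_d) = pywt.dwt2(image, wavelet)
    # Low frequency: inverse DWT with every detail subband zeroed (None means zeros).
    low = pywt.idwt2((approx, (None, None, None)), wavelet)
    # High frequency: inverse DWT with the approximation subband zeroed.
    high = pywt.idwt2((None, (detail_h, detail_v, detail_d)), wavelet)
    # For odd sizes the inverse transform returns one extra row/column; crop it off.
    return minmax_to_uint8(low[:h, :w]), minmax_to_uint8(high[:h, :w])
```

Each returned array could then be written to the L and H folders under the same file name. For 3D volumes, pywt.dwtn and pywt.idwtn play the same role, with the all-approximation band (key "aaa") as the low-frequency component.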


Architecture of the low-frequency (LF) and high-frequency (HF) fusion module.

Quantitative Comparison

Comparison with fully- and semi-supervised state-of-the-art models on the GlaS and CREMI test sets. Semi-supervised models are based on UNet. DS indicates deep supervision, * indicates lightweight models, ‡ indicates training for 1000 epochs, and - indicates that training failed. Red and bold indicate the best and second-best performance, respectively.

Comparison with fully- and semi-supervised state-of-the-art models on the LA and LiTS test sets. Due to GPU memory limitations, some semi-supervised models use smaller architectures: ✝ and * indicate models based on a lightweight 3D UNet (half the channels) and VNet, respectively. ‡ indicates training for 1000 epochs, and - indicates that training failed. Red and bold indicate the best and second-best performance, respectively.

Qualitative Comparison


Qualitative results on GlaS, CREMI, LA and LiTS. (a) Raw images. (b) Ground truth. (c) MT. (d) Semi-supervised XNet (3D XNet). (e) UNet (3D UNet). (f) Fully-supervised XNet (3D XNet). The orange arrows highlight the differences among the results.

Reimplemented Architecture

We have reimplemented several 2D and 3D models for supervised and semi-supervised semantic segmentation.

Method           Dimension  Model           Code
Supervised       2D         UNet            models/networks_2d/unet.py
                            UNet++          models/networks_2d/unet_plusplus.py
                            Att-UNet        models/networks_2d/unet.py
                            Aerial LaneNet  models/networks_2d/aerial_lanenet.py
                            MWCNN           models/networks_2d/mwcnn.py
                            HRNet           models/networks_2d/hrnet.py
                            Res-UNet        models/networks_2d/resunet.py
                            WDS             models/networks_2d/wds.py
                            U2-Net          models/networks_2d/u2net.py
                            UNet 3+         models/networks_2d/unet_3plus.py
                            SwinUNet        models/networks_2d/swinunet.py
                            WaveSNet        models/networks_2d/wavesnet.py
                            XNet (Ours)     models/networks_2d/xnet.py
                 3D         VNet            models/networks_3d/vnet.py
                            UNet 3D         models/networks_3d/unet3d.py
                            Res-UNet 3D     models/networks_3d/res_unet3d.py
                            ESPNet 3D       models/networks_3d/espnet3d.py
                            DMFNet 3D       models/networks_3d/dmfnet.py
                            ConResNet       models/networks_3d/conresnet.py
                            CoTr            models/networks_3d/cotr.py
                            TransBTS        models/networks_3d/transbts.py
                            UNETR           models/networks_3d/unetr.py
                            XNet 3D (Ours)  models/networks_3d/xnet3d.py
Semi-Supervised  2D         MT              train_semi_MT.py
                            EM              train_semi_EM.py
                            UAMT            train_semi_UAMT.py
                            CCT             train_semi_CCT.py
                            CPS             train_semi_CPS.py
                            URPC            train_semi_URPC.py
                            CT              train_semi_CT.py
                            XNet (Ours)     train_semi_XNet.py
                 3D         MT              train_semi_MT_3d.py
                            EM              train_semi_EM_3d.py
                            UAMT            train_semi_UAMT_3d.py
                            CCT             train_semi_CCT_3d.py
                            CPS             train_semi_CPS_3d.py
                            URPC            train_semi_URPC_3d.py
                            CT              train_semi_CT_3d.py
                            DTC             train_semi_DTC.py
                            XNet 3D (Ours)  train_semi_XNet3d.py

Requirements

albumentations==0.5.2
einops==0.4.1
MedPy==0.4.0
numpy==1.20.2
opencv_python==4.2.0.34
opencv_python_headless==4.5.1.48
Pillow==8.0.0
PyWavelets==1.1.1
scikit_image==0.18.1
scikit_learn==1.0.1
scipy==1.4.1
SimpleITK==2.1.0
timm==0.6.7
torch==1.8.0+cu111
torchio==0.18.53
torchvision==0.9.0+cu111
tqdm==4.65.0
visdom==0.1.8.9

Usage

Data preparation: your dataset directory tree should look like this:

See tools/wavelet2D.py and tools/wavelet3D.py for how the L (low-frequency) and H (high-frequency) images are generated.

dataset
├── train_sup_100
│   ├── L
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   ├── H
│   │   ├── 1.tif
│   │   ├── 2.tif
│   │   └── ...
│   └── mask
│       ├── 1.tif
│       ├── 2.tif
│       └── ...
├── train_sup_20
│   ├── L
│   ├── H
│   └── mask
├── train_unsup_80
│   ├── L
│   └── H
└── val
    ├── L
    ├── H
    └── mask
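Before training, it may help to confirm that each split holds matching file names across its subfolders. The helper below, check_split, is hypothetical and not part of the repository. It assumes, as in the tree above, that a sample is identified by the same file name in L, H and mask, and that unlabeled splits such as train_unsup_80 have no mask folder.

```python
from pathlib import Path


def check_split(split_dir, require_mask=True):
    """Check that L, H (and optionally mask) contain the same file names.

    Hypothetical helper. Returns the sorted shared file names and raises
    ValueError if a folder is missing or the names differ.
    """
    split_dir = Path(split_dir)
    subdirs = ["L", "H"] + (["mask"] if require_mask else [])
    names = {}
    for sub in subdirs:
        folder = split_dir / sub
        if not folder.is_dir():
            raise ValueError(f"missing folder: {folder}")
        # Collect plain file names only (e.g. "1.tif"), ignoring nested folders.
        names[sub] = {p.name for p in folder.iterdir() if p.is_file()}
    reference = names["L"]
    for sub, found in names.items():
        if found != reference:
            # Symmetric difference lists files present in one folder but not the other.
            diff = sorted(found ^ reference)
            raise ValueError(f"{split_dir.name}/{sub} differs from L in: {diff}")
    return sorted(reference)
```

For example, check_split("dataset/train_sup_20") and check_split("dataset/train_unsup_80", require_mask=False) would cover a labeled and an unlabeled split.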

Supervised training

python -m torch.distributed.launch --nproc_per_node=4 train_sup_XNet.py

Semi-supervised training

python -m torch.distributed.launch --nproc_per_node=4 train_semi_XNet.py

Testing

python -m torch.distributed.launch --nproc_per_node=4 test.py
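All three commands use torch.distributed.launch, which with the pinned PyTorch 1.8 starts --nproc_per_node copies of the script (here four, one per GPU) and appends --local_rank=<i> to each copy. The sketch below shows the usual way a script consumes that argument. We have not checked that train_sup_XNet.py, train_semi_XNet.py or test.py are written exactly this way; parse_launch_args and init_distributed are hypothetical names for the standard pattern.

```python
import argparse
import os


def parse_launch_args(argv=None):
    """Read the per-process rank that torch.distributed.launch passes as --local_rank.

    Hypothetical helper illustrating the standard launcher contract.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_rank",
        type=int,
        # Fall back to the LOCAL_RANK environment variable, which newer launchers set.
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    # parse_known_args leaves any other script-specific options untouched.
    args, _ = parser.parse_known_args(argv)
    return args


def init_distributed(local_rank):
    """Pin this process to one GPU and join the NCCL process group (requires CUDA).

    init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE,
    which the launcher sets for every process.
    """
    import torch  # imported lazily so parse_launch_args works without a GPU setup
    import torch.distributed as dist

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
```

Inside a launched script this would be called as init_distributed(parse_launch_args().local_rank). On a machine with fewer GPUs, lowering --nproc_per_node (for example to 1) starts correspondingly fewer processes.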

Citation

If our work is useful for your research, please cite our paper:

@InProceedings{Zhou_2023_ICCV,
  author = {Zhou, Yanfeng and Huang, Jiaxing and Wang, Chenlong and Song, Le and Yang, Ge}, 
  title = {XNet: Wavelet-Based Low and High Frequency Fusion Networks for Fully- and Semi-Supervised Semantic Segmentation of Biomedical Images}, 
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 
  month = {October}, 
  year = {2023}, 
  pages = {21085-21096}
  }