DASQE

This is the implementation of "A Collaborative Self-supervised Domain Adaptation for Low-Quality Medical Image Enhancement".

Data preparation

  1. Download the datasets from the following links:

  2. Split each dataset into training and test sets.
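A minimal sketch of the split step, assuming each dataset starts as one flat folder of images and is split at the image level into the Train_Folder/Test_Folder names used in the directory layout of this README. The function name `split_dataset`, the 20% test ratio, and the seed are illustrative assumptions, not values from the paper or the repository.

```python
import random
import shutil
from pathlib import Path


def split_dataset(src_dir, dst_dir, test_ratio=0.2, seed=0):
    """Copy the images in src_dir into Train_Folder/ and Test_Folder/ under dst_dir.

    Hypothetical helper: test_ratio and seed are illustrative defaults.
    Returns (train_names, test_names).
    """
    src, dst = Path(src_dir), Path(dst_dir)
    files = sorted(p for p in src.iterdir() if p.is_file())
    rng = random.Random(seed)  # fixed seed gives a reproducible split
    rng.shuffle(files)
    n_test = int(round(len(files) * test_ratio))
    splits = {"Test_Folder": files[:n_test], "Train_Folder": files[n_test:]}
    for folder, subset in splits.items():
        out = dst / folder
        out.mkdir(parents=True, exist_ok=True)
        for f in subset:
            shutil.copy2(f, out / f.name)  # copy, leaving the originals untouched
    return ([f.name for f in splits["Train_Folder"]],
            [f.name for f in splits["Test_Folder"]])
```

Splitting whole images before any patch extraction keeps patches cut from the same image from appearing in both the training and the test set.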

Construction of multiple patch domains

  1. Serialize the original images into patches with 'patch_serialize.py'.

  2. Initialize the high- and low-quality domains $\mathbb{H}$ and $\mathbb{L}$ with the clustering algorithm in the './LocalAggregation' folder, then update $\mathbb{H}$ and $\mathbb{L}$ with 'quality_assessment_scheme.py'.

  3. Construct the source style domains $\mathbb{S}$ and target style domains $\mathbb{T}$ from $\mathbb{H}$, again using the clustering algorithm in './LocalAggregation'.
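To illustrate the idea behind steps 1 and 2, here is a sketch that tiles an image into non-overlapping square patches and groups the patches with plain k-means on two hand-crafted features (per-patch mean and standard deviation of intensity). This is a swapped-in stand-in: it is neither 'patch_serialize.py' nor the Local Aggregation method in './LocalAggregation', and the names `serialize_patches`, `patch_features`, and `kmeans`, the patch size, and the features are all hypothetical. Deciding which cluster forms $\mathbb{H}$ and which forms $\mathbb{L}$, and refining them, is left to the repository's own scripts.

```python
import numpy as np


def serialize_patches(image, patch_size):
    """Cut an (H, W) or (H, W, C) array into non-overlapping square tiles.

    Border pixels that do not fill a whole patch are dropped.
    Returns shape (N, patch_size, patch_size[, C]) in row-major tile order.
    """
    h, w = image.shape[:2]
    rows, cols = h // patch_size, w // patch_size
    patches = [
        image[r * patch_size:(r + 1) * patch_size, c * patch_size:(c + 1) * patch_size]
        for r in range(rows)
        for c in range(cols)
    ]
    return np.stack(patches)


def patch_features(patches):
    """Hypothetical features: mean and standard deviation of each patch."""
    flat = patches.reshape(len(patches), -1).astype(np.float64)
    return np.stack([flat.mean(axis=1), flat.std(axis=1)], axis=1)


def kmeans(features, k, n_iter=50, seed=0):
    """Plain k-means; returns one integer cluster label per row of `features`."""
    rng = np.random.default_rng(seed)
    centers = features[rng.choice(len(features), size=k, replace=False)]
    labels = np.zeros(len(features), dtype=int)
    for _ in range(n_iter):
        # Assign each point to its nearest center (squared Euclidean distance).
        dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its points; keep it if its cluster is empty.
        new_centers = np.array([
            features[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels
```

For example, `kmeans(patch_features(serialize_patches(img, 64)), k=2)` would return one of two group labels per patch of `img`.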

Then arrange the patches in the following directory structure for training and testing:

(data_l, data_s, and data_t represent patches from $\mathbb{L}$, $\mathbb{S}$ and $\mathbb{T}$, respectively)

├── dataset
│   ├── EyeQ
│   │   ├── Test_Folder
│   │   │   ├── data_l
│   │   │   ├── data_s
│   │   │   └── data_t
│   │   └── Train_Folder
│   │       ├── data_l
│   │       ├── data_s
│   │       └── data_t
│   ├── Corneal Nerve
│   │   ├── Test_Folder
│   │   │   ├── data_l
│   │   │   ├── data_s
│   │   │   └── data_t
│   │   └── Train_Folder
│   │       ├── data_l
│   │       ├── data_s
│   │       └── data_t
│   ├── ISIC
│   │   ......
│   ├── Endoscopy
│   │   ......
│   ├── Chest X-ray
│   │   ......
│   └── Cardiac MRI
│       ......
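The sketch below creates this directory skeleton and reports any expected folder that is missing. It assumes that the elided datasets (ISIC, Endoscopy, Chest X-ray, Cardiac MRI) follow the same Train_Folder/Test_Folder and data_l/data_s/data_t layout as EyeQ; the helper names `make_layout` and `missing_folders` are hypothetical and not part of the repository.

```python
from pathlib import Path

# Dataset names from the tree; the inner layout of the elided ones is assumed.
DATASETS = ["EyeQ", "Corneal Nerve", "ISIC", "Endoscopy", "Chest X-ray", "Cardiac MRI"]
SPLITS = ["Train_Folder", "Test_Folder"]
DOMAINS = ["data_l", "data_s", "data_t"]  # patches from L, S, and T respectively


def expected_folders(root="dataset"):
    """All <root>/<dataset>/<split>/<domain> paths the layout calls for."""
    return [Path(root) / n / s / d for n in DATASETS for s in SPLITS for d in DOMAINS]


def make_layout(root="dataset"):
    """Create the empty folder skeleton (idempotent); returns the paths."""
    paths = expected_folders(root)
    for path in paths:
        path.mkdir(parents=True, exist_ok=True)
    return paths


def missing_folders(root="dataset"):
    """Expected folders that do not exist yet under root."""
    return [p for p in expected_folders(root) if not p.is_dir()]
```

Running `missing_folders()` before training is a cheap way to catch a misspelled split or domain folder.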

Package install

Install the required packages:

pip install -r requirements.txt

Training

Adjust the settings in configs/unit_noise2clear-bn.yaml as needed, then train the DASQE model:

python train.py --output_path ${LOG_DIR}

Testing and visualization

To enhance low-quality test images with a trained checkpoint, run:

python test.py --input_a ${Low_quality_path} --input_c ${target_style_path} --output_folder ${output_image_path} --checkpoint ${pretrain_model_path} --psnr
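The --psnr flag suggests that test.py reports the peak signal-to-noise ratio. For reference, here is a sketch of the standard definition, PSNR = 10 · log10(MAX² / MSE) in dB, assuming 8-bit images (MAX = 255) and a reference image of the same shape. Which reference test.py compares against, and what data range it uses, is not stated in this README, so the function `psnr` below is illustrative only.

```python
import numpy as np


def psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_value**2 / MSE)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float("inf")  # identical images: PSNR is unbounded
    return float(10.0 * np.log10(max_value ** 2 / mse))
```

For images scaled to [0, 1], pass `max_value=1.0`; using the wrong peak value shifts every score by a constant.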