This is the implementation of "Towards Unifying Medical Vision-and-Language Pre-training via Soft Prompts" (ICCV 2023).
Run the following command to install the required packages:
pip install -r requirements.txt
You can either (1) preprocess the data yourself by following the instructions below, or (2) apply for our preprocessed data here (download the data and downloaded folders). For option (2), please make sure you attach your PhysioNet license to the request (e.g., a link to a screenshot of the license).
The project structure should be:
| +--datasets
| +--datamodules
| +--metrics
| +--models
| +--pretrain_arrows
| +--finetune_arrows
| +--finetune_vision_arrows
| +--finetune_language_arrows
| +--roberta-base
| +--biomed_roberta_base
| +--scorers
| +--meter.ckpt
| +--ptunifier.ckpt
Please organize the pre-training datasets in the following structure:
| +--roco
| | +--val
| | +--test
| | +--train
| +--mimic_cxr
| | +--files
| | +--mimic-cxr-2.0.0-split.csv
| | +--mimic-cxr-2.0.0-metadata.csv
| | +--mimic-cxr-2.0.0-chexpert.csv
| | +--mimic_cxr_sectioned.csv
| +--medicat
| | +--release
| | +--net
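Before running the preprocessing step, it can help to confirm that the raw files sit where the structure above expects them. The following is a minimal sketch and not part of the repository: the helper `missing_paths` and the constant `PRETRAIN_LAYOUT` are hypothetical names, the entries simply mirror the tree above, and the root directory they live under is an assumption you should adjust to your setup.

```python
from pathlib import Path

# Relative paths copied from the tree above. The root directory that
# contains them is an assumption and is passed in by the caller.
PRETRAIN_LAYOUT = [
    "roco/train",
    "roco/val",
    "roco/test",
    "mimic_cxr/files",
    "mimic_cxr/mimic-cxr-2.0.0-split.csv",
    "mimic_cxr/mimic-cxr-2.0.0-metadata.csv",
    "mimic_cxr/mimic-cxr-2.0.0-chexpert.csv",
    "mimic_cxr/mimic_cxr_sectioned.csv",
    "medicat/release",
    "medicat/net",
]


def missing_paths(root, expected=PRETRAIN_LAYOUT):
    """Return the expected entries that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    import sys

    # First CLI argument is the dataset root; defaults to the current directory.
    root = sys.argv[1] if len(sys.argv) > 1 else "."
    missing = missing_paths(root)
    for rel in missing:
        print(f"missing: {rel}")
    print("layout OK" if not missing else f"{len(missing)} entries missing")
```

The same helper can be reused for the fine-tuning datasets by passing a different `expected` list.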
Run the following command to pre-process the data:
python prepro/
to get the following arrow files:
| +--medicat_train.arrow
| +--medicat_val.arrow
| +--medicat_test.arrow
| +--roco_train.arrow
| +--roco_val.arrow
| +--roco_test.arrow
| +--mimic_cxr_train.arrow
| +--mimic_cxr_val.arrow
| +--mimic_cxr_test.arrow
Download the initialized METER weights here.
Now we can start to pre-train the PTUnifier model:
bash run_scripts/
Please organize the fine-tuning datasets in the following structure:
| +--melinda
| | +--train.csv
| | +--dev.csv
| | +--test.csv
| | +--melinda_images
| +--slack
| | +--train.json
| | +--validate.json
| | +--test.json
| | +--imgs
| +--vqa_rad
| | +--trainset.json
| | +--valset.json
| | +--testset.json
| | +--images
| +--medvqa_2019
| | +--val
| | +--test
| | +--train
| +--chexpert
| | +--CheXpert-v1.0-small
| +--rsna_pneumonia
| | +--stage_2_test_images
| | +--stage_2_train_labels.csv
| | +--stage_2_train_images
| +--mednli
| | +--mli_train_v1.jsonl
| | +--mli_test_v1.jsonl
| | +--mli_dev_v1.jsonl
| +--radnli
| | +--radnli_pseudo-train.jsonl
| | +--radnli_test_v1.jsonl
| | +--radnli_dev_v1.jsonl
Run the following command to pre-process the data:
python prepro/
to get the following arrow files:
| +--vqa_vqa_rad_train.arrow
| +--vqa_vqa_rad_val.arrow
| +--vqa_vqa_rad_test.arrow
| +--vqa_slack_train.arrow
| +--vqa_slack_test.arrow
| +--vqa_slack_val.arrow
| +--vqa_medvqa_2019_train.arrow
| +--vqa_medvqa_2019_val.arrow
| +--vqa_medvqa_2019_test.arrow
| +--irtr_roco_train.arrow
| +--irtr_roco_val.arrow
| +--irtr_roco_test.arrow
| +--mlc_chexpert_train_001.arrow
| +--mlc_chexpert_train_01.arrow
| +--mlc_chexpert_train.arrow
| +--mlc_chexpert_val.arrow
| +--mlc_chexpert_test.arrow
| +--mlc_pnsa_pneumonia_train_001.arrow
| +--mlc_pnsa_pneumonia_train_01.arrow
| +--mlc_pnsa_pneumonia_train.arrow
| +--mlc_pnsa_pneumonia_val.arrow
| +--mlc_pnsa_pneumonia_test.arrow
| +--clm_mimic_cxr_train.arrow
| +--clm_mimic_cxr_val.arrow
| +--clm_mimic_cxr_test.arrow
| +--nli_radnli_plus_train.arrow
| +--nli_radnli_plus_val.arrow
| +--nli_radnli_plus_test.arrow
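The file names above appear to follow a `<task>_<dataset>_<split>[_<subset>].arrow` pattern. The sketch below parses names under that assumption, which is inferred from the listing rather than taken from the code. The names `ARROW_NAME`, `parse_arrow_name`, and `datasets_by_task` are hypothetical, and the numeric suffixes (`001`, `01`) are treated only as opaque subset tags.

```python
import re
from collections import defaultdict

# Pattern inferred from the file names listed above (an assumption):
# <task>_<dataset>_<split>[_<subset>].arrow
ARROW_NAME = re.compile(
    r"^(?P<task>vqa|irtr|mlc|clm|nli)_"
    r"(?P<dataset>.+)_"
    r"(?P<split>train|val|test)"
    r"(?:_(?P<subset>\d+))?\.arrow$"
)


def parse_arrow_name(name):
    """Split a fine-tuning arrow file name into its parts, or return None."""
    match = ARROW_NAME.match(name)
    return match.groupdict() if match else None


def datasets_by_task(names):
    """Map each task prefix to the sorted dataset names found in `names`."""
    grouped = defaultdict(set)
    for name in names:
        parts = parse_arrow_name(name)
        if parts is not None:  # silently skip names that do not fit the pattern
            grouped[parts["task"]].add(parts["dataset"])
    return {task: sorted(datasets) for task, datasets in grouped.items()}


if __name__ == "__main__":
    print(parse_arrow_name("mlc_chexpert_train_001.arrow"))
    print(datasets_by_task(["vqa_vqa_rad_train.arrow", "vqa_slack_val.arrow"]))
```

Grouping the contents of the output directory this way gives a quick overview of which tasks have been preprocessed.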
Now you can start to fine-tune the PTUnifier model:
bash run_scripts/
Supported Tasks:
- Uni-modal Tasks
  - Multi-label Classification on CheXpert
  - Classification on RSNA Pneumonia
  - Classification on RadNLI
  - Radiology Report Summarization on MIMIC-CXR
- Cross-modal Tasks
  - Cross-modal Retrieval on ROCO (Zero-shot)
  - Cross-modal Retrieval on ROCO (Fine-tuned)
  - Radiology Report Generation on MIMIC-CXR
- Multi-modal Tasks
  - Visual Question Answering on VQA-RAD
  - Visual Question Answering on SLACK
  - Visual Question Answering on MedVQA-2019
  - Multi-modal Radiology Report Summarization on MIMIC-CXR
Add the hyper-parameters in ptunifier/
Add a new head in ptunifier/modules/
Add a new objective in ptunifier/modules/
Add new metrics and logging scheme in ptunifier/modules/
Add the new prediction heads to the optimizer for lr multiplier in ptunifier/modules/
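The last step, giving new prediction heads their own learning-rate multiplier, usually amounts to splitting parameters into optimizer groups by name. The sketch below illustrates that idea in plain Python so it runs without PyTorch. It is not the repository's code: the function `build_param_groups`, the head name `my_task_head`, and the values `1e-5` and `10` are hypothetical placeholders.

```python
def build_param_groups(named_params, head_names, base_lr, lr_mult):
    """Split (name, param) pairs into two optimizer groups.

    Parameters whose name contains any entry of `head_names` get
    `base_lr * lr_mult`; all other parameters keep `base_lr`.
    """
    backbone, heads = [], []
    for name, param in named_params:
        is_head = any(head in name for head in head_names)
        (heads if is_head else backbone).append(param)
    return [
        {"params": backbone, "lr": base_lr},
        {"params": heads, "lr": base_lr * lr_mult},
    ]


if __name__ == "__main__":
    # Stand-in parameter list; with PyTorch this would come from
    # model.named_parameters(). The strings "w0"/"w1" replace real tensors.
    params = [
        ("text_transformer.layer.0.weight", "w0"),
        ("my_task_head.fc.weight", "w1"),  # hypothetical new head
    ]
    for group in build_param_groups(params, ["my_task_head"], 1e-5, 10):
        print(group)
```

The returned list of dictionaries has the shape that PyTorch optimizers such as `torch.optim.AdamW` accept for per-group options, so a newly added head only needs its name included in `head_names`.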
The code is based on ViLT, METER and MAE. We thank the authors for their open-sourced code and encourage users to cite their works when applicable.