
COVID-19 EHR Benchmarks

A Comprehensive Benchmark For COVID-19 Predictive Modeling Using Electronic Health Records


TJH datasets and presentation slides are available in the GitHub Releases.

We are fixing some minor issues in the data processing pipeline. The code on GitHub is correct, but the arXiv paper is currently outdated; we will upload a new version soon.

Prediction Tasks

  • (Early) Mortality outcome prediction
  • Length-of-stay prediction
  • Multi-task/Two-stage prediction

Model Zoo

Machine Learning Models

  • Random forest (RF)
  • Decision tree (DT)
  • Gradient boosting decision tree (GBDT)
  • XGBoost
  • CatBoost

Deep Learning Models

  • Multi-layer perceptron (MLP)
  • Recurrent neural network (RNN)
  • Long short-term memory network (LSTM)
  • Gated recurrent units (GRU)
  • Temporal convolutional network (TCN)
  • Transformer

EHR Predictive Models

  • RETAIN
  • StageNet
  • Dr. Agent
  • AdaCare
  • ConCare
  • GRASP

Code Description

app/
    apis/
        ml_{task}.py # machine learning pipelines
        dl_{task}.py # deep learning pipelines
    core/
        evaluation/ # evaluation metrics
        utils/
    datasets/ # dataset loader scripts
    models/
        backbones/ # feature extractors
        classifiers/ # prediction heads
        losses/ # task-related loss functions
        build_model.py # combines backbones and prediction heads
configs/
    _base_/ # common configs
        datasets/ # dataset basic info, training epochs and dataset split strategy
            {dataset}.yaml
        db.yaml # database settings (optional)
    {config_name}.yaml # detailed model settings
checkpoints/ # model checkpoints are stored here
datasets/ # raw/processed dataset and pre-process script
main.py # main entry point
requirements.txt # code dependencies

Requirements

  • Python 3.7+
  • PyTorch 1.10+
  • CUDA 10.2+ (if you plan to use a GPU)

Note:

  • Most models can be run quickly on a CPU.
  • A GPU with at least 12 GB of memory is required to run the ConCare model on the CDSL dataset.
  • The TCN model may run much faster on CPU than on GPU.

Usage

  • Install requirements.

    pip install -r requirements.txt [-i https://pypi.tuna.tsinghua.edu.cn/simple] # the bracketed mirror index is optional
  • Download the TJH dataset from the paper An interpretable mortality prediction model for COVID-19 patients, unzip it, and put it in the datasets/tongji/raw_data/ folder.

  • Run the preprocessing notebook. (You can skip this step if preprocessing has already been done as part of the training process below.)

  • The CDSL dataset follows the same process. If needed, apply for access to the Covid Data Save Lives Dataset.

  • Run the following command to train models.

    python main.py --cfg configs/xxx.yaml [--train] [--cuda CUDA_NUM] [--db]
    # Note:
    # 1) pass --train to train the model; without it, only the inference stage runs
    # 2) if you plan to use CUDA, pass --cuda 0/1/2/...
    # 3) if you have configured database settings, pass --db to upload the performance metrics to the database after training
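
For example, to train the GRU outcome model on the TJH dataset on GPU 0, using one of the tuned config names listed under Configs below (assuming, as in the Code Description above, that the config files live directly under configs/):

    python main.py --cfg configs/tj_outcome_gru_ep100_kf10_bs64_hid64.yaml --train --cuda 0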

Data Format

The shape and meaning of the tensors fed to the models are as follows:

  • x.pkl: (N, T, D) tensor, where N is the number of patients, T is the number of time steps, and D is the number of features. Along the $D$ dimension, the first $x$ features are demographic features and the next $y$ features are lab test features, where $x + y = D$.
  • y.pkl: (N, T, 2) tensor, where the 2 values are [outcome, length-of-stay] for each time step.
  • visits_length.pkl: (N, ) tensor, where the value is the number of visits for each patient.
  • missing_mask.pkl: same shape as x.pkl; indicates whether each feature value was observed or imputed. 1: existing, 0: missing.

Pre-processed data are stored in datasets/{dataset}/processed_data/ folder.
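
As a quick sanity check, the tensors can be loaded and their shapes verified with a few lines of Python. This is a minimal sketch, not part of the repository: the datasets/tongji path comes from the Usage section above, and whether the pickles hold NumPy arrays or PyTorch tensors depends on the preprocessing script.

import pickle

def load_pkl(path):
    # each file holds a single pickled tensor, per the Data Format above
    with open(path, "rb") as f:
        return pickle.load(f)

data_dir = "datasets/tongji/processed_data"
x = load_pkl(f"{data_dir}/x.pkl")                          # (N, T, D) features
y = load_pkl(f"{data_dir}/y.pkl")                          # (N, T, 2) [outcome, LOS]
visits_length = load_pkl(f"{data_dir}/visits_length.pkl")  # (N,) visits per patient
missing_mask = load_pkl(f"{data_dir}/missing_mask.pkl")    # (N, T, D) 1: existing, 0: missing

assert x.shape == missing_mask.shape
assert len(x) == len(y) == len(visits_length)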

Database preparation [Optional]

Example db.yaml settings; put the file at configs/_base_/db.yaml.

engine: postgresql # or mysql
username: db_user
password: db_password
host: xx.xxx.com
port: 5432
database: db_name
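
These settings map directly onto a standard database URL. Below is a hedged sketch of how they could be turned into a connection with SQLAlchemy; the repository's actual database layer may differ, and the matching driver (e.g. psycopg2 for PostgreSQL, pymysql for MySQL) must be installed.

import yaml
from sqlalchemy import create_engine

with open("configs/_base_/db.yaml") as f:
    cfg = yaml.safe_load(f)

# e.g. postgresql://db_user:db_password@xx.xxx.com:5432/db_name
url = (
    f"{cfg['engine']}://{cfg['username']}:{cfg['password']}"
    f"@{cfg['host']}:{cfg['port']}/{cfg['database']}"
)
engine = create_engine(url)  # ready for writes to the perflog table below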

Create perflog table in your database:

-- postgresql example
create table perflog
(
	id serial
		constraint perflog_pk
			primary key,
	record_time integer,
	model_name text,
	performance text,
	hidden_dim integer,
	dataset text,
	model_type text,
	config text,
	task text
);

-- mysql example
create table perflog
(
	id int auto_increment,
	record_time int null,
	model_name text null,
	task text null,
	performance text null,
	hidden_dim int null,
	dataset text null,
	model_type text null,
	config text null,
	constraint perflog_id_uindex
		unique (id)
);

alter table perflog
	add primary key (id);
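
For reference, a row written by --db might look like the following. This is illustrative only: the epoch-seconds record_time and the JSON-in-text encoding of performance are assumptions, not formats the repository prescribes; the config value reuses a name from the Configs section below.

-- illustrative example row; values are assumed except the schema itself
insert into perflog (record_time, model_name, performance, hidden_dim,
                     dataset, model_type, config, task)
values (1663200000, 'gru', '{"auroc": 0.95, "auprc": 0.90}', 64,
        'tj', 'dl', 'tj_outcome_gru_ep100_kf10_bs64_hid64', 'outcome');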

Configs

Below are the configurations after hyperparameter selection. The file names encode the settings used: the dataset (tj for TJH, hm for CDSL), the task (outcome, los, multitask, or twostage), the model, and its hyperparameters, e.g. ep (epochs), kf (k-fold splits), bs (batch size), hid (hidden dimension), and, for the ML models, md (max depth), lr (learning rate), ne (number of estimators), and similar abbreviations.

ML models
hm_los_catboost_kf10_md6_iter150_lr0.1_test
hm_los_decision_tree_kf10_md10_test
hm_los_gbdt_kf10_lr0.1_ss0.8_ne100_test
hm_los_random_forest_kf10_md10_mss2_ne100_test
hm_los_xgboost_kf10_lr0.01_md5_cw3_test
hm_outcome_catboost_kf10_md3_iter150_lr0.1_test
hm_outcome_decision_tree_kf10_md10_test
hm_outcome_gbdt_kf10_lr0.1_ss0.6_ne100_test
hm_outcome_random_forest_kf10_md20_mss10_ne100_test
hm_outcome_xgboost_kf10_lr0.1_md7_cw3_test
tj_los_catboost_kf10_md3_iter150_lr0.1_test
tj_los_decision_tree_kf10_md10_test
tj_los_gbdt_kf10_lr0.1_ss0.8_ne100_test
tj_los_random_forest_kf10_md20_mss5_ne100_test
tj_los_xgboost_kf10_lr0.01_md5_cw1_test
tj_outcome_catboost_kf10_md3_iter150_lr0.1_test
tj_outcome_decision_tree_kf10_md10_test
tj_outcome_gbdt_kf10_lr0.1_ss0.6_ne100_test
tj_outcome_random_forest_kf10_md20_mss2_ne10_test
tj_outcome_xgboost_kf10_lr0.1_md5_cw5_test
DL/EHR models
tj_outcome_grasp_ep100_kf10_bs64_hid64
tj_los_grasp_ep100_kf10_bs64_hid128
tj_outcome_concare_ep100_kf10_bs64_hid128
tj_los_concare_ep100_kf10_bs64_hid128
tj_outcome_agent_ep100_kf10_bs64_hid128
tj_los_agent_ep100_kf10_bs64_hid64
tj_outcome_adacare_ep100_kf10_bs64_hid64
tj_los_adacare_ep100_kf10_bs64_hid64
tj_outcome_transformer_ep100_kf10_bs64_hid128
tj_los_transformer_ep100_kf10_bs64_hid64
tj_outcome_tcn_ep100_kf10_bs64_hid128
tj_los_tcn_ep100_kf10_bs64_hid128
tj_outcome_stagenet_ep100_kf10_bs64_hid64
tj_los_stagenet_ep100_kf10_bs64_hid64
tj_outcome_rnn_ep100_kf10_bs64_hid64
tj_los_rnn_ep100_kf10_bs64_hid128
tj_outcome_retain_ep100_kf10_bs64_hid128
tj_los_retain_ep100_kf10_bs64_hid128
tj_outcome_mlp_ep100_kf10_bs64_hid64
tj_los_mlp_ep100_kf10_bs64_hid128
tj_outcome_lstm_ep100_kf10_bs64_hid64
tj_los_lstm_ep100_kf10_bs64_hid128
tj_outcome_gru_ep100_kf10_bs64_hid64
tj_los_gru_ep100_kf10_bs64_hid128
tj_multitask_rnn_ep100_kf10_bs64_hid64
tj_multitask_lstm_ep100_kf10_bs64_hid128
tj_multitask_gru_ep100_kf10_bs64_hid128
tj_multitask_transformer_ep100_kf10_bs64_hid128
tj_multitask_tcn_ep100_kf10_bs64_hid64
tj_multitask_mlp_ep100_kf10_bs64_hid128
tj_multitask_adacare_ep100_kf10_bs64_hid128
tj_multitask_agent_ep100_kf10_bs64_hid64
tj_multitask_concare_ep100_kf10_bs64_hid128
tj_multitask_stagenet_ep100_kf10_bs64_hid64
tj_multitask_grasp_ep100_kf10_bs64_hid128
tj_multitask_retain_ep100_kf10_bs64_hid64
hm_outcome_mlp_ep100_kf10_bs64_hid64
hm_los_mlp_ep100_kf10_bs64_hid128
hm_outcome_lstm_ep100_kf10_bs64_hid64
hm_los_lstm_ep100_kf10_bs64_hid128
hm_outcome_gru_ep100_kf10_bs64_hid64
hm_los_gru_ep100_kf10_bs64_hid128
hm_outcome_grasp_ep100_kf10_bs64_hid64
hm_los_grasp_ep100_kf10_bs64_hid64
hm_outcome_concare_ep100_kf10_bs64_hid128
hm_los_concare_ep100_kf10_bs64_hid64
hm_outcome_agent_ep100_kf10_bs64_hid128
hm_los_agent_ep100_kf10_bs64_hid64
hm_outcome_adacare_ep100_kf10_bs64_hid64
hm_los_adacare_ep100_kf10_bs64_hid128
hm_outcome_transformer_ep100_kf10_bs64_hid128
hm_los_transformer_ep100_kf10_bs64_hid128
hm_outcome_tcn_ep100_kf10_bs64_hid64
hm_los_tcn_ep100_kf10_bs64_hid128
hm_outcome_stagenet_ep100_kf10_bs64_hid64
hm_los_stagenet_ep100_kf10_bs64_hid64
hm_outcome_rnn_ep100_kf10_bs64_hid64
hm_los_rnn_ep100_kf10_bs64_hid128
hm_outcome_retain_ep100_kf10_bs64_hid128
hm_los_retain_ep100_kf10_bs64_hid128
hm_multitask_rnn_ep100_kf10_bs512_hid128
hm_multitask_lstm_ep100_kf10_bs512_hid64
hm_multitask_gru_ep100_kf10_bs512_hid128
hm_multitask_transformer_ep100_kf10_bs512_hid64
hm_multitask_tcn_ep100_kf10_bs512_hid64
hm_multitask_mlp_ep100_kf10_bs512_hid128
hm_multitask_adacare_ep100_kf10_bs512_hid128
hm_multitask_agent_ep100_kf10_bs512_hid128
hm_multitask_concare_ep100_kf10_bs64_hid128
hm_multitask_stagenet_ep100_kf10_bs512_hid128
hm_multitask_grasp_ep100_kf10_bs512_hid64
hm_multitask_retain_ep100_kf10_bs512_hid128
Two-stage configs
tj_twostage_adacare_kf10.yaml
tj_twostage_agent_kf10.yaml
tj_twostage_concare_kf10.yaml
tj_twostage_gru_kf10.yaml
tj_twostage_lstm_kf10.yaml
tj_twostage_mlp_kf10.yaml
tj_twostage_retain_kf10.yaml
tj_twostage_rnn_kf10.yaml
tj_twostage_stagenet_kf10.yaml
tj_twostage_tcn_kf10.yaml
tj_twostage_transformer_kf10.yaml
tj_twostage_grasp_kf10.yaml
hm_twostage_adacare_kf10.yaml
hm_twostage_agent_kf10.yaml
hm_twostage_concare_kf10.yaml
hm_twostage_gru_kf10.yaml
hm_twostage_lstm_kf10.yaml
hm_twostage_mlp_kf10.yaml
hm_twostage_retain_kf10.yaml
hm_twostage_rnn_kf10.yaml
hm_twostage_stagenet_kf10.yaml
hm_twostage_tcn_kf10.yaml
hm_twostage_transformer_kf10.yaml
hm_twostage_grasp_kf10.yaml

Contributing

We appreciate all contributions to improve covid-ehr-benchmarks. Pull Requests and Issues are welcome!

Contributors

Yinghao Zhu, Wenqing Wang, Junyi Gao

Citation

If you find this project useful in your research, please consider citing:

@misc{https://doi.org/10.48550/arxiv.2209.07805,
  doi = {10.48550/ARXIV.2209.07805},
  url = {https://arxiv.org/abs/2209.07805},
  author = {Gao, Junyi and Zhu, Yinghao and Wang, Wenqing and Wang, Yasha and Tang, Wen and Ma, Liantao},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences},
  title = {A Comprehensive Benchmark for COVID-19 Predictive Modeling Using Electronic Health Records in Intensive Care: Choosing the Best Model for COVID-19 Prognosis},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}

License

This project is released under the GPL-2.0 license.