The source code for Dr. Agent: Clinical Predictive Model via Mimicked Second Opinions
- Install Python and PyTorch. We use Python 3.7.3 and PyTorch 1.1.
- If you plan to use GPU computation, install CUDA.
We provide the trained weights in ./saved_weights/. You can reproduce the performance reported in our paper by loading the weights into the model with the following code:
checkpoint = torch.load('./saved_weights/TASK_TO_TEST')
save_chunk = checkpoint['chunk']
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
We do not provide the MIMIC-III data itself. You must acquire the data yourself from https://mimic.physionet.org/. Specifically, download the CSVs. To run the MIMIC-III benchmark tasks, you should first build the benchmark dataset according to https://github.com/YerevaNN/mimic3-benchmarks/.
After building the benchmark dataset, there will be a directory data/{task} for each created benchmark task. Then run extract_demo.py to extract demographics from the dataset (change TASK to the specific task).
- You can train Dr. Agent on different tasks by running the corresponding files.
- The minimum input you need to run Dr. Agent is the dataset directory and the model save directory:
$ python train_decomp.py --data_path='./decompensation/data/' --save_path='./saved_weights/'
- You can specify the batch size with --batch_size <integer>, the learning rate with --lr <float>, and the number of epochs with --epochs <integer>.
- Additional hyper-parameters can be specified, such as the RNN dimension and whether to use LSTM or GRU. Detailed information is available via:
$ python train_decomp.py --help
- When training is complete, the script outputs the performance of Dr. Agent on the test dataset.
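The command-line options listed above could be declared roughly as in the sketch below. This is an illustration only, not the parser used by train_decomp.py: the function name build_parser and every default value are assumptions, and the real script exposes further options (see --help).

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags named in this README."""
    parser = argparse.ArgumentParser(description="Train Dr. Agent (illustrative sketch)")
    # Required paths: where the benchmark data lives and where weights are written.
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--save_path", type=str, required=True)
    # Optional training settings; these defaults are placeholders, not the paper's values.
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=10)
    return parser


# Parse an explicit argument list so the sketch runs without a shell.
args = build_parser().parse_args(
    ["--data_path=./decompensation/data/", "--save_path=./saved_weights/", "--batch_size=64"]
)
print(args.data_path, args.batch_size, args.lr, args.epochs)
```

Flags that are omitted on the command line fall back to the declared defaults, which is why only the two paths are strictly required.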
The minimal inputs to the Dr. Agent model are:
- EHR records (batch_size, time_step, feature_num): The EHR records for a mini-batch of patients.
- Masks (batch_size, time_step): Since all patients' records are padded to the same length to form batches, masks are binary values indicating whether the current timestep is padding or a real value.
- Demographics (batch_size, demo_features): Demographic features of patients. If demographics are not available for your dataset, pass zeros instead.
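As a rough illustration of the three inputs above, the following dependency-free sketch pads variable-length records and builds the matching masks. The helper pad_batch is hypothetical and not part of this repository, and the convention that 1 marks a real timestep and 0 marks padding is an assumption; check the training scripts for the convention actually used. In practice the nested lists would be converted to tensors before being passed to the model.

```python
def pad_batch(records, demographics=None, demo_features=1):
    """Pad variable-length records and build the matching masks.

    records: list of patients, each a list of timesteps, each a list of floats.
    Returns nested lists shaped (batch_size, time_step, feature_num),
    (batch_size, time_step), and (batch_size, demo_features).
    """
    max_len = max(len(patient) for patient in records)
    feature_num = len(records[0][0])
    padded, masks = [], []
    for patient in records:
        n_pad = max_len - len(patient)
        # Copy the real timesteps, then append all-zero rows as padding.
        rows = [list(step) for step in patient]
        rows += [[0.0] * feature_num for _ in range(n_pad)]
        padded.append(rows)
        # Assumed convention: 1 marks a real timestep, 0 marks padding.
        masks.append([1] * len(patient) + [0] * n_pad)
    if demographics is None:
        # No demographics available: fall back to zeros.
        demographics = [[0.0] * demo_features for _ in records]
    return padded, masks, demographics


records = [
    [[0.5, 1.0], [0.7, 0.9], [0.6, 1.1]],  # patient with 3 timesteps
    [[0.2, 0.4]],                          # patient with 1 timestep
]
padded, masks, demo = pad_batch(records, demo_features=4)
print(masks)
```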
You can directly use the model structures in the ./model/ directory for different purposes:
- model_decomp.py: binary classification with outputs at each timestep
- model_los.py: multi-label prediction
- model_mortality.py: binary classification with output at the last timestep
- model_phenotyping.py: multi-class prediction
You can also modify the structure for your specific tasks.
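When adapting a model, one point to keep in mind is how the masks interact with the two binary settings: per-timestep outputs should be scored only on real timesteps, and a last-timestep output should be read at the final real timestep rather than at the padded end. The plain-Python sketch below illustrates both ideas for a single patient. The functions masked_bce and last_real_output are hypothetical and not taken from the model files, and the sketch assumes that mask value 1 means a real timestep and that padding sits at the end of the sequence.

```python
import math


def masked_bce(probs, labels, mask):
    """Mean binary cross-entropy over real timesteps only (mask == 1)."""
    eps = 1e-7  # guards against log(0)
    total, count = 0.0, 0
    for p, y, m in zip(probs, labels, mask):
        if m:
            p = min(max(p, eps), 1.0 - eps)
            total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
            count += 1
    return total / count if count else 0.0


def last_real_output(outputs, mask):
    """Return the output at the final real timestep, assuming trailing padding."""
    n_real = sum(mask)
    if n_real == 0:
        raise ValueError("sequence has no real timesteps")
    return outputs[n_real - 1]


probs = [0.9, 0.2, 0.5, 0.5]   # predicted probabilities per timestep
labels = [1, 0, 0, 0]          # per-timestep binary labels
mask = [1, 1, 0, 0]            # last two timesteps are padding
loss = masked_bce(probs, labels, mask)
last = last_real_output(probs, mask)
print(loss, last)
```

The padded timesteps contribute nothing to the loss, so the result depends only on the first two predictions.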
Junyi Gao, Cao Xiao, Lucas M Glass, Jimeng Sun,
Dr. Agent: Clinical predictive model via mimicked second opinions,
Journal of the American Medical Informatics Association, ocaa074, https://doi.org/10.1093/jamia/ocaa074