The Hierarchical Label-wise Attention Network (HLAN) was introduced in this paper. The implementation provided by the paper's author relies on an old version of TensorFlow (v1) and runs to about 3,000 lines of code. In this project, I reimplemented the described HLAN architecture in PyTorch, improving readability and brevity (a roughly 90% shorter code base). I also applied the model to patient discharge notes to predict ICD labels (diagnosis names), achieving the same accuracy as reported in the paper. Other improvements include:
- Improved logging using the Weights and Biases (wandb) library
- Improved interpretability using the general attribution methods provided by the Captum library
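The key idea behind HLAN is label-wise attention: each label gets its own attention query, so different labels attend to different words. A minimal sketch of that mechanism (dimensions and names here are illustrative, not the project's actual module):

```python
# Minimal label-wise attention sketch; hidden_dim/num_labels are
# hypothetical values, not the project's real configuration.
import torch
import torch.nn as nn

class LabelWiseAttention(nn.Module):
    """One learnable query vector per label, so each label attends to
    different words -- this is what makes per-label attribution possible."""
    def __init__(self, hidden_dim: int, num_labels: int):
        super().__init__()
        self.label_queries = nn.Parameter(torch.randn(num_labels, hidden_dim))

    def forward(self, token_states: torch.Tensor) -> torch.Tensor:
        # token_states: (batch, seq_len, hidden_dim)
        # scores: (batch, num_labels, seq_len)
        scores = torch.einsum("lh,bsh->bls", self.label_queries, token_states)
        weights = scores.softmax(dim=-1)
        # One context vector per label: (batch, num_labels, hidden_dim)
        return torch.einsum("bls,bsh->blh", weights, token_states)

states = torch.randn(2, 30, 64)                 # 2 notes, 30 tokens each
attn = LabelWiseAttention(hidden_dim=64, num_labels=50)
label_contexts = attn(states)
print(label_contexts.shape)                     # torch.Size([2, 50, 64])
```

Each of the 50 per-label context vectors is then scored independently to produce that label's probability.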
- datasets: link to download
- embeddings: link to download
The folder structure should be something like this:
hlan-mimic-project/
cache_vocabulary_label_pik/
datasets/
embeddings/
HLAN_pytorch/
conda create -n hlan python=3.11
conda activate hlan
pip install -r requirements.txt
Make sure you are inside the HLAN_pytorch folder when launching training.
The mini dataset has 10 data points. Use this option when testing the training script, since training completes quickly.
cd HLAN_pytorch
python train.py --epochs 5 --mini
cd HLAN_pytorch
python train.py --epochs 5 --log --verbose
This requires a Weights and Biases account. Follow the steps here to create an account and log in from the terminal.
Training loss and metrics are logged to Weights and Biases when --log is set. Use --verbose to print each epoch's results to the console.
By default, the best and last checkpoints are saved to the ./checkpoints folder during training. To resume training from a checkpoint, pass its path to train.py:
cd HLAN_pytorch
python train.py --epochs 5 --log --checkpoint_to_resume_from ../checkpoints/20231127_1604_qxDMJ/last.pt
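Resuming follows the standard PyTorch checkpoint pattern; a hedged sketch (the actual dictionary keys used by train.py may differ):

```python
# Sketch of the save/resume checkpoint pattern; keys and the toy model
# are illustrative, not train.py's actual code.
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(8, 4)                         # stand-in for the HLAN model
optimizer = torch.optim.Adam(model.parameters())
path = os.path.join(tempfile.mkdtemp(), "last.pt")

# Saving during training:
torch.save({
    "epoch": 5,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}, path)

# Resuming:
loaded = torch.load(path)
model.load_state_dict(loaded["model_state"])
optimizer.load_state_dict(loaded["optimizer_state"])
start_epoch = loaded["epoch"] + 1               # continue from the next epoch
print(start_epoch)                              # 6
```

Restoring the optimizer state as well as the model weights keeps Adam's moment estimates intact across the restart.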
To train on an NVIDIA GPU, use --device gpu. To train on Apple's GPU via MPS, use --device mps.
cd HLAN_pytorch
python train.py --epochs 5 --device mps --mini
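Internally, a device flag like this typically maps to a torch.device with a CPU fallback. A sketch of that resolution logic (the flag names come from the README; the fallback behavior is an assumption, not necessarily what train.py does):

```python
# Illustrative device resolution for a --device flag; the fallback
# logic is an assumption about train.py, not its actual code.
import torch

def resolve_device(name: str) -> torch.device:
    if name == "gpu" and torch.cuda.is_available():
        return torch.device("cuda")
    if name == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")                  # safe fallback

device = resolve_device("mps")                  # cpu if MPS is unavailable
x = torch.ones(2, 2, device=device)             # model and tensors both move here
print(x.sum().item())                           # 4.0
```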
To train with Label Embedding initialization, pass --le:
cd HLAN_pytorch
python train.py --epochs 5 --mini --le
To train without it, simply omit the flag.
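Label Embedding initialization means the per-label attention parameters start from pretrained label embeddings instead of random values. A minimal sketch of the idea (names and dimensions are illustrative):

```python
# Sketch of Label Embedding (LE) initialization: copy pretrained label
# embeddings into the per-label attention queries. All names here are
# illustrative, not the project's actual identifiers.
import torch
import torch.nn as nn

num_labels, hidden_dim = 50, 64

# Pretrained label embeddings (e.g. learned from label co-occurrence);
# random here purely as a stand-in.
pretrained_label_emb = torch.randn(num_labels, hidden_dim)

label_queries = nn.Parameter(torch.empty(num_labels, hidden_dim))
with torch.no_grad():
    label_queries.copy_(pretrained_label_emb)   # LE initialization

# The queries now start from the pretrained values and remain trainable.
print(torch.equal(label_queries.data, pretrained_label_emb))  # True
```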
The paper argued that HLAN's major advantage over other architectures is that it can attribute the prediction of each label to particular words in the discharge notes. We attempted to interpret the model's label prediction decisions using the Captum Python package. See HLAN_pytorch/visualization_playground.ipynb for the code and outputs.
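Captum's attribution methods generalize gradient-based attribution. As a rough illustration of the underlying idea (the notebook uses Captum's actual implementations), a gradient-times-input attribution in plain PyTorch with a toy model:

```python
# Toy gradient*input attribution sketch; the model and dimensions are
# stand-ins, and Captum's methods (used in the notebook) are more robust.
import torch
import torch.nn as nn

torch.manual_seed(0)
embeddings = torch.randn(1, 10, 16, requires_grad=True)  # one note, 10 tokens
model = nn.Sequential(nn.Flatten(), nn.Linear(160, 5))   # toy 5-label scorer

label_idx = 3
score = model(embeddings)[0, label_idx]      # logit for the label of interest
score.backward()

# Per-token attribution: gradient * input, summed over embedding dims.
# High-magnitude entries mark words that drove this label's prediction.
attributions = (embeddings.grad * embeddings).sum(dim=-1).squeeze(0)
print(attributions.shape)                    # torch.Size([10])
```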