Cross-Modal-Transformer

Official repository of the Cross-Modal Transformer for interpretable automatic sleep stage classification: https://arxiv.org/abs/2208.06991

Primary language: Jupyter Notebook

Towards Interpretable Sleep Stage Classification Using Cross-Modal Transformers

Citation

If you find our work or this repository useful, please consider starring ⭐ the repository and citing our paper.

@article{pradeepkumar2022towards,
  title={Towards Interpretable Sleep Stage Classification Using Cross-Modal Transformers},
  author={Pradeepkumar, Jathurshan and Anandakumar, Mithunjha and Kugathasan, Vinith and Suntharalingham, Dhinesh and Kappel, Simon L and De Silva, Anjula C and Edussooriya, Chamira US},
  journal={arXiv preprint arXiv:2208.06991},
  year={2022}
}

Abstract

Accurate sleep stage classification is important for sleep health assessment. In recent years, several machine learning and deep learning based sleep staging algorithms have been developed, and they have achieved performance on par with human annotation. Despite this improved performance, a limitation of most deep learning based algorithms is their black-box behavior, which has limited their use in clinical settings. Here, we propose Cross-Modal Transformers, a transformer-based method for sleep stage classification. Our models achieve performance competitive with state-of-the-art approaches and eliminate the black-box behavior of deep learning models by utilizing the interpretability of the attention modules. The proposed cross-modal transformers consist of a novel cross-modal transformer encoder architecture along with a multi-scale one-dimensional convolutional neural network for automatic representation learning. Our sleep stage classifier based on this design achieves performance on par with or better than the state-of-the-art approaches, along with interpretability, a fourfold reduction in the number of parameters, and reduced training time compared to the current state of the art. This repository contains the implementation of the epoch and sequence cross-modal transformers and their interpretations.
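To illustrate the cross-modal attention idea, here is a minimal, dependency-free sketch of scaled dot-product attention in which the queries come from one modality and the keys and values come from another (EEG and EOG are assumed here as the two modalities). It is not the repository's implementation: learned query/key/value projections, multiple heads, the multi-scale 1D CNN front end and the feed-forward layers are all omitted, and the function names are hypothetical. The returned attention weights are the quantity that makes this kind of model inspectable.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum for numerical stability.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_modal_attention(query_feats: Matrix, context_feats: Matrix) -> Tuple[Matrix, Matrix]:
    """Hypothetical sketch: queries from one modality (e.g. EEG tokens),
    keys and values from another (e.g. EOG tokens).

    Projections are identity here to keep the example small.
    Returns (output, weights); weights[i][j] is how strongly query token i
    attends to context token j, and each row sums to 1.
    """
    d = len(query_feats[0])
    scores = matmul(query_feats, transpose(context_feats))
    scaled = [[s / math.sqrt(d) for s in row] for row in scores]
    weights = [softmax(row) for row in scaled]
    output = matmul(weights, context_feats)
    return output, weights


if __name__ == "__main__":
    eeg_tokens = [[1.0, 0.0], [0.0, 1.0]]              # 2 tokens, dim 2
    eog_tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, dim 2
    out, attn = cross_modal_attention(eeg_tokens, eog_tokens)
    for row in attn:
        print([round(w, 3) for w in row])
```

Each EEG token ends up as a weighted mix of EOG tokens, and printing the weight rows shows which EOG tokens drove that mix.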

[Figures: Epoch Cross-Modal Transformer (Epoch_CMT) and Sequence Cross-Modal Transformer (Seq_CMT)]

Getting Started

Installation Guide

The code runs on PyTorch with CUDA (https://pytorch.org/). Install the pinned PyTorch build, then the remaining requirements:

pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install -r requirements.txt
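To confirm that the pinned builds were actually installed, a small standard-library check like the one below can be run afterwards. It is not part of the repository; it simply compares installed versions against the pins from the pip command above using importlib.metadata (Python 3.8+).

```python
from importlib import metadata
from typing import Dict, Optional

# Versions pinned in the pip3 install command above.
EXPECTED = {
    "torch": "1.10.0+cu113",
    "torchvision": "0.11.1+cu113",
    "torchaudio": "0.10.0+cu113",
}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pins(expected: Dict[str, str]) -> Dict[str, str]:
    """Report 'ok', 'missing' or a mismatch message for each pinned package."""
    report = {}
    for pkg, want in expected.items():
        got = installed_version(pkg)
        if got is None:
            report[pkg] = "missing"
        elif got == want:
            report[pkg] = "ok"
        else:
            report[pkg] = f"mismatch (installed {got}, expected {want})"
    return report


if __name__ == "__main__":
    for pkg, status in check_pins(EXPECTED).items():
        print(f"{pkg}: {status}")
```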

Dataset Generation

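To extract the dataset used to train the models, refer to the accompanying Colab notebook (the "Open In Colab" link in the repository) or run the data preparation script:

python "./data_preparations/single_epoch.py" --save_path "path/to/save/dataset"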
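Sleep recordings are scored in fixed 30-second epochs, so an epoch-level dataset starts by cutting each signal into 30-second segments. The sketch below shows only that segmentation step; the 100 Hz sampling rate and the function name are assumptions for illustration, and single_epoch.py may handle channels, labels and edge cases differently.

```python
from typing import List, Sequence

EPOCH_SECONDS = 30      # standard sleep-scoring epoch length
SAMPLING_RATE_HZ = 100  # assumed sampling rate of the channel


def split_into_epochs(signal: Sequence[float],
                      fs: int = SAMPLING_RATE_HZ,
                      epoch_s: int = EPOCH_SECONDS) -> List[List[float]]:
    """Cut a 1-D recording into non-overlapping fixed-length epochs.

    Trailing samples that do not fill a whole epoch are dropped.
    """
    samples_per_epoch = fs * epoch_s
    n_epochs = len(signal) // samples_per_epoch
    return [list(signal[i * samples_per_epoch:(i + 1) * samples_per_epoch])
            for i in range(n_epochs)]


if __name__ == "__main__":
    recording = [0.0] * (95 * SAMPLING_RATE_HZ)  # 95 s of dummy signal
    epochs = split_into_epochs(recording)
    print(len(epochs), len(epochs[0]))
```

With 95 seconds of signal at 100 Hz this yields three epochs of 3000 samples each; the final 5 seconds are discarded.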

Inference and Get Interpretations

Refer to the accompanying Colab notebook (the "Open In Colab" link in the repository) to obtain predictions for any subject in the Sleep-EDF dataset and to reproduce the interpretation results reported in the paper.
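Per-epoch predictions are usually read as a hypnogram. The sketch below maps predicted class indices to stage names and totals the time spent in each stage. The class order W, N1, N2, N3, REM is an assumption and should be checked against the repository's label encoding; the helper names are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Sequence

# Assumed class order; verify against the repository's label encoding.
STAGES = ["W", "N1", "N2", "N3", "REM"]
EPOCH_SECONDS = 30


def to_hypnogram(pred_indices: Sequence[int]) -> List[str]:
    """Map per-epoch class indices to stage names."""
    return [STAGES[i] for i in pred_indices]


def stage_minutes(hypnogram: Sequence[str]) -> Dict[str, float]:
    """Total time spent in each stage, in minutes (one epoch = 30 s)."""
    counts = Counter(hypnogram)
    return {s: counts.get(s, 0) * EPOCH_SECONDS / 60 for s in STAGES}


if __name__ == "__main__":
    hyp = to_hypnogram([0, 0, 2, 2, 2, 4])
    print(hyp)
    print(stage_minutes(hyp))
```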

Train Cross-Modal Transformers

Train Epoch Cross-Modal Transformer

python cmt_training.py --project_path "./results/<project name>" --data_path "path/to/dataset" --train_data_list <training folds as a list, e.g. [0,1,2,3]> --val_data_list <validation folds as a list, e.g. [4]> --model_type "Epoch"
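The --train_data_list and --val_data_list flags select which data folds are used for training and validation. Assuming five folds numbered 0 to 4, as the [0,1,2,3] / [4] example suggests, a cross-validation run rotates the validation fold. The hypothetical helper below builds one training command per fold; it assumes the script accepts the bracketed list form shown in the example, and it only prints the commands.

```python
import shlex
from typing import List, Tuple

N_FOLDS = 5  # assumed number of folds, matching the [0,1,2,3] / [4] example


def fold_split(val_fold: int, n_folds: int = N_FOLDS) -> Tuple[List[int], List[int]]:
    """Use one fold for validation and all remaining folds for training."""
    train = [f for f in range(n_folds) if f != val_fold]
    return train, [val_fold]


def list_arg(folds: List[int]) -> str:
    # Same format as the example in the command above, e.g. [0,1,2,3]
    return "[" + ",".join(str(f) for f in folds) + "]"


def training_command(project: str, data_path: str, val_fold: int,
                     model_type: str = "Epoch") -> List[str]:
    train, val = fold_split(val_fold)
    return ["python", "cmt_training.py",
            "--project_path", f"./results/{project}",
            "--data_path", data_path,
            "--train_data_list", list_arg(train),
            "--val_data_list", list_arg(val),
            "--model_type", model_type]


if __name__ == "__main__":
    for fold in range(N_FOLDS):
        # shlex.join quotes the bracketed lists so the shell does not glob them;
        # pass the list itself to subprocess.run(...) to execute it instead.
        print(shlex.join(training_command(f"epoch_cmt_fold{fold}", "path/to/dataset", fold)))
```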

Train Sequence Cross-Modal Transformer

python cmt_training.py --project_path "./results/<project name>" --data_path "path/to/dataset" --train_data_list <training folds as a list, e.g. [0,1,2,3]> --val_data_list <validation folds as a list, e.g. [4]> --model_type "Seq"
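Unlike the epoch model, the sequence model works on a sequence of consecutive epochs rather than a single epoch. As a hypothetical illustration of how epoch-level data can be grouped for such a model, the sketch below forms fixed-length windows of consecutive epochs. The seq_len and stride parameters are assumptions, and the repository's own sequence construction may differ.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def make_sequences(epochs: Sequence[T], seq_len: int, stride: int = 1) -> List[List[T]]:
    """Group consecutive epochs into fixed-length, possibly overlapping windows.

    Windows that would run past the end of the recording are not produced.
    """
    if seq_len <= 0 or stride <= 0:
        raise ValueError("seq_len and stride must be positive")
    return [list(epochs[i:i + seq_len])
            for i in range(0, len(epochs) - seq_len + 1, stride)]


if __name__ == "__main__":
    print(make_sequences(["e0", "e1", "e2", "e3", "e4"], seq_len=3))
```

A stride of 1 gives overlapping windows (more training samples); setting stride equal to seq_len gives non-overlapping ones.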

Evaluate Cross-Modal Transformers

Get Sleep Staging Results

Evaluate Epoch Cross-Modal Transformer

python cmt_evaluate.py --project_path "./results/<project name>" --data_path "path/to/dataset" --val_data_list <validation folds as a list, e.g. [4]> --model_type "Epoch" --batch_size 1

Evaluate Sequence Cross-Modal Transformer

python cmt_evaluate.py --project_path "./results/<project name>" --data_path "path/to/dataset" --val_data_list <validation folds as a list, e.g. [4]> --model_type "Seq" --batch_size 1
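Sleep staging results are commonly summarized with overall accuracy, macro-averaged F1 and Cohen's kappa. As a reference for reading evaluation output, here is a small standard-library sketch of those three metrics over per-epoch labels. It is not the code in cmt_evaluate.py, which may compute or report its metrics differently.

```python
from collections import Counter
from typing import Sequence


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    """Fraction of epochs whose predicted stage matches the reference."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true: Sequence, y_pred: Sequence) -> float:
    """Unweighted mean of per-class F1 over every class seen in either sequence."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


def cohen_kappa(y_true: Sequence, y_pred: Sequence) -> float:
    """Agreement corrected for chance: (p_o - p_e) / (1 - p_e)."""
    n = len(y_true)
    p_o = accuracy(y_true, y_pred)
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    p_e = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    if p_e == 1.0:
        # Degenerate case: both sequences contain one identical class only.
        return 1.0
    return (p_o - p_e) / (1 - p_e)


if __name__ == "__main__":
    truth = ["W", "N2", "W", "N2"]
    preds = ["W", "W", "W", "W"]
    print(accuracy(truth, preds), macro_f1(truth, preds), cohen_kappa(truth, preds))
```

For example, predicting W for every epoch of [W, N2, W, N2] gives an accuracy of 0.5 but a kappa of 0, because that level of agreement is exactly what chance predicts.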

Get Interpretation Results

To generate the interpretation plots, run the sequence model evaluation with --is_interpret True. The plots will be saved under "./results//interpretations/".

python cmt_evaluate.py --project_path "./results/<project name>" --data_path "path/to/dataset" --val_data_list <validation folds as a list, e.g. [4]> --model_type "Seq" --batch_size 1 --is_interpret True
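The attention weights behind these plots can also be inspected numerically. The sketch below assumes you already have one row of attention weights as a list of floats (for example, over the epochs of an input sequence or over time segments within an epoch) and ranks the positions that received the most attention. The function name and the segment_seconds parameter are hypothetical and not part of the repository.

```python
from typing import List, Sequence, Tuple


def top_attended(weights: Sequence[float], k: int = 3,
                 segment_seconds: float = 1.0) -> List[Tuple[int, float, float]]:
    """Return (index, start_time_s, normalized_weight) for the k largest weights.

    Weights are renormalized to sum to 1 so rows from different plots
    can be compared directly.
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError("attention weights must have a positive sum")
    normalized = [w / total for w in weights]
    order = sorted(range(len(normalized)), key=lambda i: normalized[i], reverse=True)
    return [(i, i * segment_seconds, normalized[i]) for i in order[:k]]


if __name__ == "__main__":
    row = [0.1, 0.5, 0.2, 0.2]
    for idx, start, w in top_attended(row, k=2, segment_seconds=0.5):
        print(f"segment {idx} (starts at {start:.1f} s): weight {w:.2f}")
```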

Sleep Stage Classification Results

[Figure: param_model_acc]

Interpretation Results

[Figures: interpretation examples 33320_interpret and 44001_interpret]