EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network
Authors: Neeraj Wagh, Yogatheesan Varatharajah
Affiliation: Department of Bioengineering, University of Illinois at Urbana-Champaign
- ArXiv Pre-print: ** COMING SOON **
- ML4H Poster: ** COMING SOON **
- ML4H Slides: ** COMING SOON **
- Code: GitHub Repo
- Final Models, Pre-computed Features, Training Metadata: FigShare .zip
- Raw Data: MPI LEMON (no registration needed), TUH EEG Abnormal Corpus (needs registration)
This repository contains code to load the final models and reproduce the held-out test set results reported in Table 2 of the ML4H paper. All the code required to run an experiment is contained entirely inside the corresponding experiment folder. The full end-to-end framework that builds predictive models for scalp-EEG signal classification tasks is a work in progress and will be released as a separate repository when complete.
- Download 1) the pre-computed feature arrays for all 10-second windows in the dataset (power spectral density features, geodesic electrode distances, spectral coherence values), 2) final models used in Table 2 comparisons, and 3) training metadata (which window maps to which subject, target labels, sample indices, etc.) from FigShare
- Place all the feature files and relevant model files inside the directory of the experiment you want to execute. The code expects these files to be present in the experiment's root folder.
- Ensure your execution environment has the following Python dependencies installed and working:
  - Python 3.x
  - PyTorch (at least 1.4.0 for PyTorch Geometric to work)
  - PyTorch Geometric
  - Scikit-learn
- Enter the directory of the experiment you want to run.
- Execute
$ python heldout_test_run.py
to run the 10 saved final models on the held-out 30% test set of subjects using the pre-computed features. For the trivial classifiers, run
$ python chance_level_classification.py
instead. The mean and standard deviation values reported in Table 2 of the paper will be printed at the end of execution. See notes below for more details.
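Before running the scripts, it can help to confirm that the dependencies listed above are importable. The following is a minimal sketch, not part of the released code: the helper names (`parse_version`, `missing_dependencies`, `torch_is_recent_enough`) are hypothetical, and the only assumptions are the usual import names `torch`, `torch_geometric`, and `sklearn` and the 1.4.0 minimum stated above.

```python
import importlib
import importlib.util
from itertools import takewhile

# Usual import names for PyTorch, PyTorch Geometric, and scikit-learn.
REQUIRED_MODULES = ("torch", "torch_geometric", "sklearn")
MIN_TORCH = (1, 4, 0)  # minimum PyTorch version stated in the dependency list


def parse_version(text):
    """Turn a version string such as '1.13.1+cu117' into a 3-tuple of ints."""
    parts = []
    for piece in text.split("+")[0].split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))  # leading digits only
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # pad '1.4' to (1, 4, 0)
        parts.append(0)
    return tuple(parts)


def missing_dependencies(modules=REQUIRED_MODULES):
    """Return the names in `modules` that cannot be located for import."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


def torch_is_recent_enough(version_text, minimum=MIN_TORCH):
    """Check a PyTorch version string against the stated minimum."""
    return parse_version(version_text) >= minimum


if __name__ == "__main__":
    missing = missing_dependencies()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        try:
            torch = importlib.import_module("torch")
            ok = torch_is_recent_enough(torch.__version__)
            print("PyTorch", torch.__version__, "OK" if ok else "is older than 1.4.0")
        except ImportError as err:
            print("PyTorch was found but could not be imported:", err)
```

This only checks that the packages can be found and that the PyTorch version meets the stated minimum; it does not verify that the installed PyTorch Geometric build matches the installed PyTorch build.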
- All experiments were run, and all results reported, with a fixed seed (42) throughout the pipeline. Due to time constraints, we have not repeated the experiments with different seeds. The seed determines 1) which subjects (out of the total available in the pooled dataset) are held out for final testing, and 2) which subjects form the 10 train/validation folds within 10-fold cross-validation. Therefore, to reproduce the reported results exactly, you must set the seed to 42; this ensures that evaluation is done on subjects that were not seen during training of the released final models.
- We encourage you to 1) closely inspect the process flow depicted in Figure 4 to fully understand the evaluation setup, and 2) train new models using different seeds.
- The EEG-GCNN model definition can be found in EEGGraphConvNet.py in the shallow/deep EEG-GCNN experiment folders.
- The trivial classifiers are not trained on data; they rely only on the label/class imbalance information to make chance-based predictions. To switch between the two trivial models reported in the paper, change the "MODEL_TYPE" variable in the script.
- In the code, subject-level predictions and metrics (as opposed to window-level) are held in the "patient" variables. For purposes of the code, "patient" = "subject", irrespective of whether the subject is healthy or diseased.
- For the FCNN experiment, the model is contained in the EEGConvNet.py file (although it is not a convolutional network). The class/file name has been left unchanged because renaming it would make the saved models unusable.
- Raw channels are converted to an 8-channel bipolar montage before being used for modeling. Although these 8 channels are derived in the 10-20 system, the spatial connectivity between montage channels is calculated between the 10-10 system electrodes located at the center of each montage channel pair. The idealized locations of the scalp electrodes in the 10-10 configuration system are taken from the standard_1010.tsv file and used in get_sensor_distances() in EEGGraphDataset.py.
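To illustrate the seed note above, here is a minimal sketch of a seed-driven, subject-level split into a held-out 30% test set and 10 train/validation folds. It is an assumption-laden illustration, not the repository's code: the function `split_subjects` is hypothetical, it uses Python's `random` module rather than whatever the released pipeline uses, and it will not reproduce the authors' actual split. The real subject assignments are defined by the released training metadata.

```python
import random

SEED = 42            # the fixed seed stated in the notes above
TEST_FRACTION = 0.3  # held-out 30% of subjects
N_FOLDS = 10         # 10-fold cross-validation on the remaining subjects


def split_subjects(subject_ids, seed=SEED, test_fraction=TEST_FRACTION,
                   n_folds=N_FOLDS):
    """Split unique subjects into a held-out test set and CV folds.

    `subject_ids` may contain one entry per window; duplicates are removed
    so that every window of a subject ends up on the same side of a split.
    """
    subjects = sorted(set(subject_ids))  # sorted so only the seed matters
    rng = random.Random(seed)
    rng.shuffle(subjects)

    n_test = round(len(subjects) * test_fraction)
    test_subjects = subjects[:n_test]
    remaining = subjects[n_test:]

    # Deal the remaining subjects round-robin into n_folds validation groups.
    folds = []
    for i in range(n_folds):
        val = remaining[i::n_folds]
        val_set = set(val)
        train = [s for s in remaining if s not in val_set]
        folds.append((train, val))
    return test_subjects, folds


if __name__ == "__main__":
    # Hypothetical window-level subject labels: 100 subjects, 5 windows each.
    window_subjects = [f"sub-{i:03d}" for i in range(100) for _ in range(5)]
    test_subjects, folds = split_subjects(window_subjects)
    print(len(test_subjects), "held-out subjects,", len(folds), "folds")
```

Because the held-out subjects depend on the seed, evaluating the released models under a different seed could place subjects they were trained on into the test set, which is why the notes require seed 42 for exact reproduction.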
- Issues regarding non-reproducibility of results or support with the codebase should be emailed to nwagh2@illinois.edu
- Neeraj: nwagh2@illinois.edu / Website / Twitter / Google Scholar
- Yoga: varatha2@illinois.edu / Website / Google Scholar
N. Wagh, Y. Varatharajah. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. Proceedings of the ML4H Workshop, NeurIPS Conference 2020