
Graph Neural Networks for Explainable Ejection Fraction Estimation

Primary language: Python. License: MIT.

EchoGNN: Explainable Ejection Fraction Estimation with Graph Neural Networks

Official PyTorch implementation for:

Masoud Mokhtari, Teresa Tsang, Purang Abolmaesumi, and Renjie Liao, EchoGNN: Explainable Ejection Fraction Estimation with Graph Neural Networks (MICCAI 2022)

Abstract

Ejection fraction (EF) is a key indicator of cardiac function, allowing identification of patients prone to heart dysfunctions such as heart failure. EF is estimated from cardiac ultrasound videos known as echocardiograms (echo) by manually tracing the left ventricle and estimating its volume on certain frames. These estimations exhibit high inter-observer variability due to the manual process and varying video quality. Such sources of inaccuracy and the need for rapid assessment call for reliable and explainable machine learning techniques. In this work, we introduce EchoGNN, a model based on graph neural networks (GNNs) to estimate EF from echo videos. Our model first infers a latent echo-graph from the frames of one or multiple echo cine series. It then estimates weights over nodes and edges of this graph, indicating the importance of individual frames that aid EF estimation. A GNN regressor uses this weighted graph to predict EF. We show, qualitatively and quantitatively, that the learned graph weights provide explainability through identification of critical frames for EF estimation, which can be used to determine when human intervention is required. On the EchoNet-Dynamic public EF dataset, EchoGNN achieves EF prediction performance that is on par with the state of the art and provides explainability, which is crucial given the high inter-observer variability inherent in this task.
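The abstract describes a graph whose nodes are frames, with learned node and edge weights feeding a GNN regressor. As a rough intuition aid only, the toy sketch below shows one round of weighted message passing over frame embeddings, a mean-pooled linear EF head, and ranking frames by node weight to flag "critical" frames. It is written in plain Python so it runs without PyTorch; the function names (weighted_message_passing, regress_ef, critical_frames), the update rule, and all numbers are hypothetical and do not reproduce EchoGNN's actual graph inference or layers.

```python
from typing import List

Vector = List[float]


def weighted_message_passing(
    features: List[Vector],
    node_weights: List[float],
    edge_weights: List[List[float]],
) -> List[Vector]:
    """Toy update rule (an assumption): h_i' = w_i * sum_j e_ij * h_j."""
    dim = len(features[0])
    updated = []
    for i, w_i in enumerate(node_weights):
        agg = [0.0] * dim
        for j, h_j in enumerate(features):
            e_ij = edge_weights[i][j]
            for d in range(dim):
                agg[d] += e_ij * h_j[d]
        # Scale the aggregated message by the importance weight of frame i.
        updated.append([w_i * value for value in agg])
    return updated


def regress_ef(
    features: List[Vector],
    node_weights: List[float],
    edge_weights: List[List[float]],
    head: Vector,
    bias: float,
) -> float:
    """Mean-pool the updated frame embeddings, then apply a linear head."""
    updated = weighted_message_passing(features, node_weights, edge_weights)
    pooled = [sum(h[d] for h in updated) / len(updated) for d in range(len(head))]
    return sum(w * x for w, x in zip(head, pooled)) + bias


def critical_frames(node_weights: List[float], k: int) -> List[int]:
    """Indices of the k frames with the largest node weights."""
    ranked = sorted(range(len(node_weights)), key=lambda i: node_weights[i], reverse=True)
    return ranked[:k]


# Three frames with 2-D embeddings, connected as a chain 0 - 1 - 2.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
node_weights = [0.9, 0.1, 0.5]
edge_weights = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(regress_ef(features, node_weights, edge_weights, head=[10.0, 60.0], bias=20.0))
print(critical_frames(node_weights, k=2))  # [0, 2]
```

In this toy, a frame with a small node weight contributes little to the pooled representation, which is the intuition behind reading node weights as frame importance.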

Figure: EchoGNN overall architecture

Reproducing MICCAI 2022 Results

To reproduce the exact results reported in our MICCAI 2022 submission, follow the steps below:

  1. Install the required packages by following the instructions in Section Requirements.
  2. Download the dataset as described in Section Dataset.
  3. Follow the instructions in Section Preprocessing to preprocess EchoNet's CSV file.
  4. In the default.yaml config file in the configs/ directory, set dataset.dataset_path to the path of your dataset.
  5. In the same config file, make sure model.checkpoint_path is set to ./trained_models/miccai2022.pth.
  6. Run the following command:
python run.py --config_path ./configs/default.yaml --test
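Steps 4 and 5 edit two nested options, dataset.dataset_path and model.checkpoint_path. A minimal sketch of setting such dotted keys, assuming the YAML file has already been loaded into a plain Python dict (loading and saving the YAML are left out); set_option is a hypothetical helper, not part of the repository, and /data/EchoNet-Dynamic is a placeholder path.

```python
from typing import Any, Dict


def set_option(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set a nested option such as "dataset.dataset_path" in a config dict."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        # Create intermediate sections if they are missing.
        node = node.setdefault(key, {})
    node[leaf] = value


config: Dict[str, Any] = {"dataset": {}, "model": {}}
set_option(config, "dataset.dataset_path", "/data/EchoNet-Dynamic")  # placeholder path
set_option(config, "model.checkpoint_path", "./trained_models/miccai2022.pth")
print(config)
```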

Requirements

Use Conda 22, Python 3.9, and pip 22. For CUDA 11.4, run the following commands:

pip3 install torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu114
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cu102.html

To install the requirements (preferably in a virtual environment), run the following command:

pip install -U -r requirements.txt

The following optional packages may also be needed, depending on the options set in the config file:

pip install wandb
pip install prettytable
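After installation, a quick sanity check can confirm that the pinned versions from the commands above are what the environment actually resolves. A minimal sketch using the standard-library importlib.metadata; the EXPECTED table simply restates the pins in this section (torch_geometric, wandb, and prettytable are only checked for presence), and check_packages is a hypothetical helper, not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions pinned in the install commands; None means "any version".
EXPECTED: Dict[str, Optional[str]] = {
    "torch": "1.11.0",
    "torchvision": "0.12.0",
    "torchaudio": "0.11.0",
    "torch_geometric": None,
    "wandb": None,        # optional
    "prettytable": None,  # optional
}


def check_packages(expected: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Report, per distribution, whether it is missing, matching, or mismatched."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        # Ignore local build tags such as "+cu113" when comparing versions.
        base = installed.split("+")[0]
        if wanted is None or base == wanted:
            report[name] = f"ok ({installed})"
        else:
            report[name] = f"expected {wanted}, found {installed}"
    return report


if __name__ == "__main__":
    for package, status in check_packages(EXPECTED).items():
        print(f"{package}: {status}")
```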

Dataset

This project uses the EchoNet-Dynamic public EF dataset, which can be accessed here. Download the dataset to any convenient location and provide its directory in the config file as described in Section Config File.

Preprocessing

Since ES/ED frame locations are used to assess the model's performance, this information must be present in the CSV file used by the dataset. To add columns for the frame locations to the originally provided FileList.csv, use the echonet_preprocess_csv.py script located under /scripts:

python echonet_preprocess_csv.py --data_csv_path <path_to_FileList.csv> --tracing_csv_path <path_to_VolumeTracings.csv> --output_dir <path_to_output_csv>

The original FileList.csv can then be replaced with the output CSV file generated by the script.
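To make the preprocessing step concrete, here is a rough sketch of how ED/ES frame columns could be derived from VolumeTracings.csv, assuming its EchoNet-Dynamic columns (FileName, X1, Y1, X2, Y2, Frame) and that each video has exactly two traced frames. It is not the repository's echonet_preprocess_csv.py: the output column names ED_Frame and ES_Frame are hypothetical, and labelling the frame with the larger summed segment length as end-diastole is an assumed heuristic (ED is the frame of largest LV size). A trailing ".avi" is stripped so names match whether or not FileList.csv includes the extension.

```python
import csv
import math
import os
from collections import defaultdict
from typing import Dict, Iterable, List


def traced_size(rows: List[Dict[str, str]]) -> float:
    """Sum of traced segment lengths on one frame: a rough proxy for LV size."""
    return sum(
        math.hypot(float(r["X2"]) - float(r["X1"]), float(r["Y2"]) - float(r["Y1"]))
        for r in rows
    )


def add_ed_es_columns(filelist: Iterable[str], tracings: Iterable[str]) -> List[Dict[str, str]]:
    """Return FileList rows with hypothetical ED_Frame / ES_Frame columns added."""
    # Group tracing rows by video name (without ".avi") and then by frame index.
    traces: Dict[str, Dict[int, list]] = defaultdict(lambda: defaultdict(list))
    for row in csv.DictReader(tracings):
        name = os.path.splitext(row["FileName"])[0]
        traces[name][int(row["Frame"])].append(row)

    rows = []
    for row in csv.DictReader(filelist):
        frames = traces.get(os.path.splitext(row["FileName"])[0], {})
        if len(frames) == 2:
            # Assumption: the traced frame with the larger size proxy is end-diastole.
            ed, es = sorted(frames, key=lambda f: traced_size(frames[f]), reverse=True)
            row["ED_Frame"], row["ES_Frame"] = str(ed), str(es)
        else:
            # Videos without exactly two traced frames are left unlabelled.
            row["ED_Frame"] = row["ES_Frame"] = ""
        rows.append(row)
    return rows


def write_csv(rows: List[Dict[str, str]], path: str) -> None:
    """Write the augmented rows back out as a CSV file."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


# Example usage (file names are placeholders):
# with open("FileList.csv") as fl, open("VolumeTracings.csv") as vt:
#     write_csv(add_ed_es_columns(fl, vt), "FileList_with_frames.csv")
```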

Pretraining

To pretrain the model on the task of locating ED/ES frames, create a pretraining configuration YAML file similar to /configs/pretrain_default.yaml and run the following command:

python run.py --config_path <path_to_pretraining_config> --save_dir <path_to_save_models_to> --pretrain
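Frame-location pretraining needs per-frame targets for each sampled clip. One common way to encode them is sketched below, assuming a label scheme of 0 = other, 1 = ED, 2 = ES and a clip that starts at clip_start within the full video; frame_targets and this encoding are hypothetical, and the repository's actual pretraining targets may differ.

```python
from typing import List


def frame_targets(clip_start: int, clip_len: int, ed_frame: int, es_frame: int) -> List[int]:
    """Per-frame labels for one clip: 0 = other, 1 = ED, 2 = ES (hypothetical scheme)."""
    targets = [0] * clip_len
    for frame, label in ((ed_frame, 1), (es_frame, 2)):
        offset = frame - clip_start
        # ED/ES frames that fall outside the sampled clip are simply not labelled.
        if 0 <= offset < clip_len:
            targets[offset] = label
    return targets


print(frame_targets(clip_start=0, clip_len=4, ed_frame=1, es_frame=3))  # [0, 1, 0, 2]
```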

Training

To train the model (training + validation), first create a training configuration YAML file similar to /configs/default.yaml. Second, if you want to start from pretrained models obtained by following the instructions in Pretraining, specify their path with the pretrained_path option in the config file. Lastly, run the following command:

python run.py --config_path <path_to_training_config> --save_dir <path_to_save_models_to>

Evaluation

To evaluate an already trained model, first create a training configuration YAML file similar to /configs/default.yaml that matches the specifications of the trained models. Second, provide the path to the trained models using the checkpoint_path option in the config file. Lastly, run the following command:

python run.py --config_path <path_to_training_config> --test
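The pretraining, training, and evaluation commands differ only in their flags. A small sketch that assembles them from the flags documented in this README (--config_path, --save_dir, --pretrain, --test); build_run_command is a hypothetical helper, and the resulting list can be passed to subprocess.run(cmd, check=True) to launch a run.

```python
import shlex
from typing import List, Optional


def build_run_command(config_path: str, mode: str = "train",
                      save_dir: Optional[str] = None) -> List[str]:
    """Assemble a run.py invocation using only the flags documented in this README."""
    if mode not in {"pretrain", "train", "test"}:
        raise ValueError(f"unknown mode: {mode}")
    cmd = ["python", "run.py", "--config_path", config_path]
    if save_dir is not None:
        cmd += ["--save_dir", save_dir]
    if mode == "pretrain":
        cmd.append("--pretrain")
    elif mode == "test":
        cmd.append("--test")
    return cmd


print(shlex.join(build_run_command("./configs/default.yaml", mode="test")))
# prints: python run.py --config_path ./configs/default.yaml --test
```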

Config File

The default configuration can be found in ./configs/default.yaml. A summary of some important configuration options is provided below:

  • dataset
    • dataset_path: Path to the downloaded dataset
    • num_clips_per_vid: Number of random clips to extract and average during training
    • zoom_aug: Whether zoom augmentation is used during training
  • train
    • criteria
      • classification
        • lambda: Indicates the weight given to classification loss during training
      • sparsity
        • node_lambda: Indicates the weight given to node sparsity loss during training
        • edge_lambda: Indicates the weight given to edge sparsity loss during training
  • model
    • checkpoint_path: Path to the saved model used for inference
    • pretrained_path: Path to the pretrained model
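The lambda options under train.criteria weight individual loss terms. The sketch below shows how such weights are typically combined into one objective, together with averaging EF predictions over the num_clips_per_vid clips of a video. It assumes an absolute-error EF loss and a mean-absolute-weight sparsity penalty; both are assumptions rather than the repository's exact losses, total_loss and average_clip_predictions are hypothetical names, and the lambda values in the example are illustrative, not the defaults.

```python
from typing import Dict, List


def mean(values: List[float]) -> float:
    return sum(values) / len(values)


def total_loss(
    ef_pred: float,
    ef_true: float,
    classification_loss: float,
    node_weights: List[float],
    edge_weights: List[float],
    criteria: Dict[str, Dict[str, float]],
) -> float:
    """Combine loss terms using lambdas laid out like train.criteria in the config."""
    # Assumed EF regression loss: absolute error in EF percentage points.
    regression = abs(ef_pred - ef_true)
    # Assumed sparsity penalties: mean absolute weight, which pushes
    # unimportant nodes (frames) and edges toward zero.
    node_sparsity = mean([abs(w) for w in node_weights])
    edge_sparsity = mean([abs(w) for w in edge_weights])
    return (
        regression
        + criteria["classification"]["lambda"] * classification_loss
        + criteria["sparsity"]["node_lambda"] * node_sparsity
        + criteria["sparsity"]["edge_lambda"] * edge_sparsity
    )


def average_clip_predictions(clip_preds: List[float]) -> float:
    """Average the EF predictions of the clips drawn from one video."""
    return mean(clip_preds)


# Illustrative lambda values only.
criteria = {"classification": {"lambda": 0.5}, "sparsity": {"node_lambda": 1.0, "edge_lambda": 2.0}}
print(total_loss(50.0, 55.0, 2.0, [0.5, 0.5], [0.25, 0.75], criteria))  # 7.5
```

Raising node_lambda or edge_lambda in this sketch makes it cheaper for the model to keep most weights near zero, which is why sparsity terms tend to concentrate importance on a few frames.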