
Official PyTorch implementation of “ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data” (ICML 2024).


ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data (ICML 2024)



by Xiangjian Jiang, Andrei Margeloiu, Nikola Simidjievski, Mateja Jamnik.


Installing

Local Machine

The settings below have been tested with CUDA 11.2 (NVIDIA-SMI 460.32.03).

# Install the codebase
cd REPOSITORY
conda create python=3.9.0 --name protogate
conda activate protogate
pip install -r requirements.txt

# Optionally, install lightgbm
# pip install lightgbm --install-option=--gpu --install-option="--opencl-include-dir=/usr/local/cuda/include/" --install-option="--opencl-library=/usr/local/cuda/lib64/libOpenCL.so"
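
After installation, a quick sanity check of the environment can save a failed first run. The snippet below is a minimal sketch and not part of the repository: the function name `check_environment` and the report keys are illustrative assumptions. It only checks the Python version used in the conda command above, whether PyTorch is importable and sees a CUDA device, and whether the optional `lightgbm` and `wandb` packages are installed.

```python
import importlib.util
import sys

# Python version used in the conda command above (3.9.x).
REQUIRED_PYTHON = (3, 9)
# Optional packages mentioned in this README (assumed optional here).
OPTIONAL_PACKAGES = ("lightgbm", "wandb")


def check_environment() -> dict:
    """Return a small report describing the current Python environment."""
    report = {
        "python_ok": sys.version_info >= REQUIRED_PYTHON,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    for name in OPTIONAL_PACKAGES:
        report[f"{name}_installed"] = importlib.util.find_spec(name) is not None

    # Only import torch when it is present, so the check never crashes.
    if report["torch_installed"]:
        try:
            import torch

            report["cuda_available"] = bool(torch.cuda.is_available())
        except ImportError:
            report["torch_installed"] = False
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is reported as `False`, the driver or CUDA toolkit may not match the installed PyTorch build.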

Google Colab

If the environment cannot be set up on the local machine, we also provide the dependencies for Google Colab. Please follow these steps:

  • Upload the associated notebook protogate_colab.ipynb to Google Colab

  • Upload the associated codebase to Google Drive

  • Set the project_path in the notebook to the path of the uploaded codebase in Google Drive

    # Set up the path of codebase
    project_path = '/path/to/codebase/'
  • Execute the notebook cells in Step1: Set up environment
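
The Drive-related steps above can be sketched as follows. This is a hedged illustration, not the content of protogate_colab.ipynb: the helper names `in_colab` and `prepare_project` are assumptions, and `/content/drive` is only the conventional Colab mount point. Outside Colab the mount is skipped, so the same code also validates a local path.

```python
import os
import sys


def in_colab() -> bool:
    """Return True when the code runs inside a Google Colab kernel."""
    return "google.colab" in sys.modules


def prepare_project(project_path: str, mount_point: str = "/content/drive") -> str:
    """Mount Google Drive when on Colab, then validate the codebase path."""
    if in_colab():
        # Only importable on Colab; opens the Drive authorisation prompt.
        from google.colab import drive

        drive.mount(mount_point)

    resolved = os.path.abspath(os.path.expanduser(project_path))
    if not os.path.isdir(resolved):
        raise FileNotFoundError(f"Codebase not found at: {resolved}")
    return resolved


if __name__ == "__main__":
    # Replace with the location of the uploaded codebase.
    project_path = "/path/to/codebase/"
    try:
        print("Using codebase at", prepare_project(project_path))
    except FileNotFoundError as err:
        print(err)
```

Failing early on a wrong project_path avoids confusing import errors in later notebook cells.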

Running experiments

We provide scripts for running ProtoGate and the benchmark methods on real-world datasets (e.g., the “meta-pam” dataset) and synthetic datasets. Below are three examples:

  • ProtoGate on meta-pam dataset

    bash scripts/PROTOGATE/run_exp_protogate_real.sh
  • ProtoGate on Syn1 dataset

    bash scripts/PROTOGATE/run_exp_protogate_syn.sh
  • Hyperparameters can be changed by passing different values to the command in the script:

    python src/run_experiment.py \
    	--model 'protogate' \
    	--dataset 'metabric-pam50__200' \
    	--metric_model_selection total_loss \
    	--lr 0.1 \
    	--protogate_lam_global 0.0002 \
    	--protogate_lam_local 0.001 \
    	--pred_k 3 \
    	--max_steps 8000 \
    	--protogate_gating_hidden_layer_list 200 \
    	--tags 'real-world' \
    	--disable_wandb
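
To try several hyperparameter values without editing the script by hand, the command above can be generated programmatically. The following is a minimal sketch and not part of the repository: the helper names `build_command` and `sweep` and the grid values are illustrative assumptions, while the flags and default values are copied from the command above. It only prints the commands; running them is left to the user.

```python
import itertools
import shlex

# Defaults copied from the example command above.
BASE_ARGS = {
    "model": "protogate",
    "dataset": "metabric-pam50__200",
    "metric_model_selection": "total_loss",
    "lr": 0.1,
    "protogate_lam_global": 0.0002,
    "protogate_lam_local": 0.001,
    "pred_k": 3,
    "max_steps": 8000,
    "protogate_gating_hidden_layer_list": 200,
    "tags": "real-world",
}


def build_command(overrides: dict) -> list:
    """Build the argument list for one run, applying the given overrides."""
    args = {**BASE_ARGS, **overrides}
    cmd = ["python", "src/run_experiment.py"]
    for key, value in args.items():
        cmd += [f"--{key}", str(value)]
    cmd.append("--disable_wandb")  # drop this line to enable wandb logging
    return cmd


def sweep(grid: dict):
    """Yield one command per combination of the values in the grid."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield build_command(dict(zip(keys, values)))


if __name__ == "__main__":
    # Hypothetical grid; the values are for illustration only.
    grid = {"protogate_lam_global": [0.0001, 0.0002], "pred_k": [3, 5]}
    for cmd in sweep(grid):
        print(shlex.join(cmd))
```

Each printed line can be pasted into a shell, or passed as a list to `subprocess.run` from the repository root.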

FAQ

  • Q: Where to find other real-world datasets?

    A: The other HDLSS datasets can be downloaded from the source website or the open-source project, and the non-HDLSS datasets can be downloaded from the TabZilla benchmark.

  • Q: How to get the full log of the training and evaluation process?

    A: Please install the wandb library and remove the --disable_wandb flag when running the experiments.

  • Q: How to set up the path of files in Google Drive?

    A: The only extra step is to mount Google Drive in Google Colab; the remaining steps are the same as on the local machine.

  • Q: Which license does this codebase follow?

    A: This codebase will follow the Apache-2.0 license once ProtoGate is publicly available to the community.

Citation

Please cite our work if this repo helps you:

@inproceedings{jiang2024protogate,
  title={ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data},
  author={Jiang, Xiangjian and Margeloiu, Andrei and Simidjievski, Nikola and Jamnik, Mateja},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}