ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data (ICML 2024)
Official PyTorch implementation of "ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data" (ICML 2024)
by Xiangjian Jiang, Andrei Margeloiu, Nikola Simidjievski, and Mateja Jamnik.
The settings below have been tested with CUDA 11.2 (NVIDIA-SMI 460.32.03).
```bash
# Install the codebase
cd REPOSITORY
conda create python=3.9.0 --name protogate
conda activate protogate
pip install -r requirements.txt

# Optionally, install lightgbm with GPU support
# pip install lightgbm --install-option=--gpu --install-option="--opencl-include-dir=/usr/local/cuda/include/" --install-option="--opencl-library=/usr/local/cuda/lib64/libOpenCL.so"
```
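After installation, a quick sanity check can confirm that the active interpreter and PyTorch build match the tested setup (Python 3.9, CUDA 11.2). The sketch below uses a hypothetical helper, `check_environment`, which is not part of the repository; it assumes PyTorch is presumably installed through `requirements.txt` and simply reports `torch_installed: False` when it is not.

```python
import importlib.util
import sys


def check_environment(expected_python=(3, 9)):
    """Hypothetical helper: compare the active environment with the tested setup."""
    report = {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "python_matches": tuple(sys.version_info[:2]) == tuple(expected_python),
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
        "torch_cuda_version": None,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without PyTorch

        report["cuda_available"] = torch.cuda.is_available()
        # None for CPU-only PyTorch builds, otherwise a version string
        report["torch_cuda_version"] = torch.version.cuda
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Run it inside the activated `protogate` environment; a `cuda_available: False` result usually means a CPU-only PyTorch build or a driver mismatch.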
If this environment cannot be set up on your local machine, we also provide the environment dependencies for Google Colab. Please follow these steps:
- Upload the associated notebook `protogate_colab.ipynb` to Google Colab.
- Upload the associated codebase to Google Drive.
- Set `project_path` in the notebook to the path of the uploaded codebase in Google Drive:
  ```python
  # Set up the path of codebase
  project_path = '/path/to/codebase/'
  ```
- Execute the notebook cells under `Step1: Set up environment`.
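The Drive-related steps can be sketched in a single notebook cell. This is a hypothetical helper (`setup_colab_project` is not part of the repository): it calls Colab's `google.colab.drive.mount` when running inside Colab, then checks that `project_path` exists and switches into it, on the assumption that commands such as `bash scripts/...` and `python src/run_experiment.py` are run from the repository root.

```python
import os
import sys


def setup_colab_project(project_path, mount_point="/content/drive"):
    """Hypothetical helper: mount Google Drive (in Colab) and enter the codebase."""
    try:
        from google.colab import drive  # only importable inside Google Colab
    except ImportError:
        drive = None
    if drive is not None:
        drive.mount(mount_point)  # prompts for Drive authorization in Colab
    if not os.path.isdir(project_path):
        raise FileNotFoundError(f"Codebase not found at {project_path!r}")
    # Assumption: the experiment commands use paths relative to the repository root
    os.chdir(project_path)
    # Optional: lets notebook cells import modules from the codebase
    if project_path not in sys.path:
        sys.path.insert(0, project_path)
    return os.getcwd()
```

Usage in the notebook would look like `setup_colab_project('/path/to/codebase/')`, with the placeholder replaced by the Drive path of the uploaded codebase. Outside Colab the mount step is skipped, so the same cell also works locally.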
We provide scripts for running ProtoGate and the benchmark methods on real-world datasets (e.g., the “meta-pam” dataset) and synthetic datasets. Below are three examples:
- ProtoGate on the meta-pam dataset:
  ```bash
  bash scripts/PROTOGATE/run_exp_protogate_real.sh
  ```
- ProtoGate on the Syn1 dataset:
  ```bash
  bash scripts/PROTOGATE/run_exp_protogate_syn.sh
  ```
- The hyperparameters can be changed in the script by passing different values:
  ```bash
  python src/run_experiment.py \
    --model 'protogate' \
    --dataset 'metabric-pam50__200' \
    --metric_model_selection total_loss \
    --lr 0.1 \
    --protogate_lam_global 0.0002 \
    --protogate_lam_local 0.001 \
    --pred_k 3 \
    --max_steps 8000 \
    --protogate_gating_hidden_layer_list 200 \
    --tags 'real-world' \
    --disable_wandb
  ```
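To try several values of `--protogate_lam_global` and `--protogate_lam_local` without editing the scripts by hand, the `run_experiment.py` call can also be assembled from Python. This is a hypothetical sketch: `build_protogate_command` and `sweep` are not part of the repository, every flag and default value is copied from the example command, and it assumes it is launched from the repository root.

```python
import shlex
import subprocess


def build_protogate_command(lam_global, lam_local, dataset="metabric-pam50__200",
                            lr=0.1, pred_k=3, max_steps=8000, disable_wandb=True):
    """Hypothetical helper: rebuild the example run_experiment.py call as an argv list."""
    cmd = [
        "python", "src/run_experiment.py",
        "--model", "protogate",
        "--dataset", dataset,
        "--metric_model_selection", "total_loss",
        "--lr", str(lr),
        "--protogate_lam_global", str(lam_global),
        "--protogate_lam_local", str(lam_local),
        "--pred_k", str(pred_k),
        "--max_steps", str(max_steps),
        "--protogate_gating_hidden_layer_list", "200",
        "--tags", "real-world",
    ]
    if disable_wandb:
        cmd.append("--disable_wandb")
    return cmd


def sweep(lam_globals, lam_locals, dry_run=True):
    """Grid over both coefficients; only prints the commands unless dry_run=False."""
    commands = [build_protogate_command(g, l) for g in lam_globals for l in lam_locals]
    for cmd in commands:
        print(shlex.join(cmd))  # shell-quoted form, useful for copy-pasting
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop the sweep on the first failed run
    return commands
```

For example, `sweep([0.0002, 0.0005], [0.001, 0.002])` prints four commands; passing `dry_run=False` runs them sequentially.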
- Q: Where can I find the other real-world datasets?
  A: The other HDLSS datasets can be downloaded from the source website and the open-source project, and the non-HDLSS datasets can be downloaded from the TabZilla benchmark.
- Q: How can I get the full log of the training and evaluation process?
  A: Install the `wandb` library and remove the `--disable_wandb` flag when running the experiments.
- Q: How do I set up the file paths in Google Drive?
  A: The only extra step is to mount Google Drive in Google Colab; the remaining steps are the same as on a local machine.
- Q: Which license does this codebase follow?
  A: This codebase will follow the Apache-2.0 license once ProtoGate is publicly available to the community.
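Whether to pass `--disable_wandb` can also be decided at launch time. The sketch below is a hypothetical helper (`wandb_flags` is not part of the repository) that keeps W&B logging only when the `wandb` package is importable; note that `find_spec` only checks installation, so an account still has to be authenticated once with `wandb login`.

```python
import importlib.util


def wandb_flags():
    """Hypothetical helper: return the extra flag needed when wandb is unavailable."""
    if importlib.util.find_spec("wandb") is not None:
        return []  # wandb installed: omit --disable_wandb to get the full log
    return ["--disable_wandb"]  # wandb missing: disable logging so runs still start
```

The returned list can be appended to an argument list such as `["python", "src/run_experiment.py", ...]` before passing it to `subprocess.run`.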
Please cite our work if this repo helps you:
```bibtex
@inproceedings{jiang2024protogate,
  title={ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data},
  author={Jiang, Xiangjian and Margeloiu, Andrei and Simidjievski, Nikola and Jamnik, Mateja},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}
```