/PyTorch-SurvNAM

PyTorch implementation of SurvNAM (under development actively)

Primary LanguageJupyter Notebook

PyTorch Implementation of SurvNAM

PyTorch implementation of neural additive models in Neural Additive Models (PyTorch) is adopted for this implementation of SurvNAM.

For neural additive models, check out:

For random survival forests (RSF):

Dependencies

scikit-learn>=1.0.2
numpy>=1.21.5
pandas>=1.3.5
tqdm>=4.54.0
setuptools>=61.2.0

Usage

In Python:

from nam import NeuralAdditiveModel

model = NeuralAdditiveModel(input_size=x_train.shape[-1],
                            shallow_units=100,
                            hidden_units=(64, 32, 32),
                            shallow_layer=ExULayer,
                            hidden_layer=ReLULayer,
                            hidden_dropout=0.1,
                            feature_dropout=0.1)
logits, feature_nn_outputs = model.forward(x)

Citation

If you use this code in your research, please cite the following paper:

SurvNAM

Utkin, L. V., Satyukov, E. D., & Konstantinov, A. V. (2022). SurvNAM: The machine learning survival model explanation. Neural Networks, 147, 81-102.

@article{utkin2022survnam,
    title={SurvNAM: The machine learning survival model explanation},
    author={Utkin, Lev V and Satyukov, Egor D and Konstantinov, Andrei V},
    journal={Neural Networks},
    volume={147},
    pages={81--102},
    year={2022},
    publisher={Elsevier}
}

Neural Additive Models (NAM)

Agarwal, R., Frosst, N., Zhang, X., Caruana, R., & Hinton, G. E. (2020). Neural additive models: Interpretable machine learning with neural nets. arXiv preprint arXiv:2004.13912

@article{agarwal2020neural,
    title={Neural additive models: Interpretable machine learning with neural nets},
    author={Agarwal, Rishabh and Frosst, Nicholas and Zhang, Xuezhou and
    Caruana, Rich and Hinton, Geoffrey E},
    journal={arXiv preprint arXiv:2004.13912},
    year={2020}
}