nam_jax

JAX-based implementation of Neural Additive Models

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).

Neural Additive Models in JAX

This repo contains a JAX-based implementation of the model introduced in "Neural Additive Models: Interpretable Machine Learning with Neural Nets" by R. Agarwal et al. (2021).
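For reference, the paper models the target as a sum of per-feature functions, each learned by its own small neural network. Here g is a link function (identity for regression, logistic for binary classification), β is a learned bias, and f_k is the network that sees only feature x_k:

```latex
g\left(\mathbb{E}[y]\right) = \beta + f_1(x_1) + f_2(x_2) + \cdots + f_K(x_K)
```

Because each f_k depends on a single feature, its learned shape can be plotted directly, which is what makes the model interpretable.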

NAM Architecture
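A minimal sketch of the additive forward pass in plain JAX, assuming one independent MLP per input feature whose scalar outputs are summed with a bias. This is not the repository's Haiku code: `init_nam`, `nam_apply` and the other names are hypothetical, and the hidden layers use plain ReLU units rather than the ExU units proposed in the paper.

```python
import jax
import jax.numpy as jnp


def init_feature_net(key, hidden=(32, 16)):
    """Initialise one MLP that maps a single scalar feature to a scalar."""
    sizes = (1,) + tuple(hidden) + (1,)
    layers = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)  # He-style init
        layers.append((w, jnp.zeros(d_out)))
    return layers


def feature_net(layers, x_i):
    """Apply one feature net to a column x_i of shape (batch,); returns (batch,)."""
    h = x_i[:, None]  # (batch, 1)
    for w, b in layers[:-1]:
        h = jax.nn.relu(h @ w + b)
    w, b = layers[-1]
    return (h @ w + b)[:, 0]


def init_nam(key, num_features, hidden=(32, 16)):
    """One independent feature net per input column, plus a single scalar bias."""
    keys = jax.random.split(key, num_features)
    return {
        "feature_nets": [init_feature_net(k, hidden) for k in keys],
        "bias": jnp.zeros(()),
    }


def nam_apply(params, x):
    """x: (batch, num_features) -> (prediction (batch,), contributions (batch, num_features))."""
    contribs = jnp.stack(
        [feature_net(layers, x[:, i]) for i, layers in enumerate(params["feature_nets"])],
        axis=1,
    )
    # Additive structure: the output is the bias plus the sum of per-feature terms.
    return params["bias"] + contribs.sum(axis=1), contribs


params = init_nam(jax.random.PRNGKey(0), num_features=3)
x = jnp.array([[0.5, -1.0, 2.0], [1.5, 0.0, -0.3]])
pred, contribs = nam_apply(params, x)
```

Returning `contribs` alongside the prediction is what allows plotting each learned shape function f_k; changing one input column changes only that column's contribution.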

Dependencies

  • jax
  • optax
  • haiku (installed as dm-haiku): used to implement the neural network model
  • torch: used to create mini-batches
  • numpy
  • scikit-learn

Examples

Check out the nam_regression_example.ipynb notebook for an example of using the model on the California Housing dataset.
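The notebook is not reproduced here. As a hedged sketch of a typical data-preparation step for this dataset, the code below loads California Housing (8 numeric features) with scikit-learn, makes a train/test split, and standardizes the features using statistics from the training split only. The function names, the 80/20 split and the use of standardization are assumptions and may not match what the notebook does.

```python
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def prepare_splits(x, y, test_size=0.2, seed=0):
    """Hypothetical helper: split, then scale with training-set mean/std only."""
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=seed
    )
    scaler = StandardScaler().fit(x_train)  # fit on training data to avoid leakage
    return (
        scaler.transform(x_train).astype(np.float32),
        scaler.transform(x_test).astype(np.float32),
        y_train.astype(np.float32),
        y_test.astype(np.float32),
    )


def load_california_housing(test_size=0.2, seed=0):
    """Hypothetical loader; downloads the data on first call (needs network access)."""
    x, y = fetch_california_housing(return_X_y=True)
    return prepare_splits(x, y, test_size, seed)
```

The float32 casts match JAX's default precision, so the arrays can be passed straight to a jitted training step.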