Improved TabNet for TensorFlow

Primary language: Python. License: MIT.

TensorFlow TabNet

A TensorFlow 2.X implementation of the paper TabNet.

TabNet is already available for TF2 here. However, an error in the batch normalizations of the shared feature transformer blocks prevents that implementation from learning. In the original code, only the weights of the fully connected layers are shared; the batch normalizations remain unique to each block. When the batch normalizations are shared as well, training TabNet on the datasets and hyperparameters of the original paper yields extremely poor results (if the model learns at all).
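The difference between sharing only the fully connected weights and also sharing the batch normalization can be illustrated with a minimal NumPy sketch. This is not the code of this repository: `BatchNorm`, `SharedBlock`, the layer sizes and the three decision steps are hypothetical names and values chosen for illustration, and the GLU activation and learnable scale/shift of a real feature transformer are omitted.

```python
import numpy as np


class BatchNorm:
    """Tiny batch normalization that owns its moving statistics (no scale/shift)."""

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)

    def __call__(self, x, training=False):
        if training:
            # Normalize with batch statistics and update the moving averages.
            mean, var = x.mean(axis=0), x.var(axis=0)
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.eps)


class SharedBlock:
    """Hypothetical block: fully connected weights are shared, normalization is not."""

    def __init__(self, shared_weights):
        self.weights = shared_weights                 # same array object in every step
        self.bn = BatchNorm(shared_weights.shape[1])  # unique to this step

    def __call__(self, x, training=False):
        return self.bn(x @ self.weights, training=training)


rng = np.random.default_rng(0)
shared_weights = rng.normal(size=(8, 4))
blocks = [SharedBlock(shared_weights) for _ in range(3)]  # one block per decision step

# Each step sees differently distributed inputs, so its statistics diverge.
for step, block in enumerate(blocks):
    block(rng.normal(loc=step, size=(16, 8)), training=True)
```

After this loop every block still points to the same weight matrix, while each `bn` holds moving statistics fitted to the inputs of its own step. Sharing a single `BatchNorm` instance across the three blocks would instead mix the statistics of all steps into one set of moving averages.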

Moreover, that implementation uses the virtual_batch_size argument of tf.keras.layers.BatchNormalization, which has a couple of major issues:

  • It updates the batch normalizations less often than a vanilla Ghost Batch Normalization, which makes training unstable and longer.
  • It does not allow a batch size that cannot be divided by the virtual batch size used during training (even for inference, which should be independent of batch size).
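For comparison, a vanilla Ghost Batch Normalization can be sketched in a few lines of NumPy. This is an illustrative sketch, not the layer shipped in this repository: the class name `GhostBatchNorm`, its default values and the `num_updates` counter are assumptions made for the example, and the learnable scale and shift are left out. The batch is split into virtual batches, each chunk is normalized with its own statistics, and the moving statistics are updated once per chunk.

```python
import numpy as np


class GhostBatchNorm:
    """Minimal sketch of Ghost Batch Normalization (no learnable scale/shift)."""

    def __init__(self, num_features, virtual_batch_size=4, momentum=0.9, eps=1e-5):
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)
        self.num_updates = 0  # counts moving-statistic updates, for illustration only

    def __call__(self, x, training=False):
        if not training:
            # Inference only uses the moving statistics, so any batch size works.
            return (x - self.moving_mean) / np.sqrt(self.moving_var + self.eps)

        outputs = []
        # Walk over the batch in virtual batches; the last chunk may be smaller.
        for start in range(0, len(x), self.virtual_batch_size):
            chunk = x[start:start + self.virtual_batch_size]
            mean, var = chunk.mean(axis=0), chunk.var(axis=0)
            outputs.append((chunk - mean) / np.sqrt(var + self.eps))
            # One update of the moving statistics per virtual batch.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
            self.num_updates += 1
        return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
gbn = GhostBatchNorm(num_features=3, virtual_batch_size=4)
train_out = gbn(rng.normal(size=(10, 3)), training=True)  # chunks of 4, 4 and 2 rows
infer_out = gbn(rng.normal(size=(7, 3)))                  # 7 is not a multiple of 4
```

In this sketch a training batch of 10 rows triggers three updates of the moving statistics (one per chunk), and inference accepts a batch of 7 rows because it never splits the input.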

Below is a plot of the training accuracy for a model trained with a true Ghost Batch Normalization and one trained with the incorrect virtual_batch_size argument from Keras:

[Figure: training accuracy with a true Ghost Batch Normalization versus the Keras virtual_batch_size argument]

Probably unaware of the issues described above, that implementation proposes using Group Normalization instead of Ghost Batch Normalization, which does improve things and allows the model to learn. However, no comparison with the results of the original paper is provided. Are they even close to the original ones? Can the model really generalize and obtain state-of-the-art results?

Therefore, a new, correct and tested TF2 implementation of TabNet is proposed. It not only ports the original code but also takes advantage of the modular approach of TF2 to make it easier to fine-tune and train a TabNet model on different tasks.

The project is currently in development but can already be used.

Setup

Install

Since this project is still under development, it is best to install it directly from the repo as follows:

pip install git+https://github.com/ostamand/tensorflow-tabnet.git

You can then import & train a classifier using:

from tabnet.models.classify import TabNetClassifier

For development

python -m venv venv
source venv/bin/activate
chmod +x scripts/install.sh; ./scripts/install.sh

Dataset

Reference