This repo contains the code to run over 1500 experiments that compare the
performance of Deep Learning algorithms for tabular data with LightGBM
.
Deep Learning models for tabular data are run via the pytorch-widedeep library.
Companion post: pytorch-widedeep, deep learning for tabular data IV: Deep Learning vs LightGBM
For the experiments in this repo I have used four datasets:
- Adult Census (binary classification)
- Bank Marketing (binary classification)
- NYC taxi ride duration (regression)
- Facebook Comment Volume (regression)
And mainly four deep learning models:
- TabMlp: a simple MLP very similar to the tabular api implementation in the fastai library
- TabResnet: similar to the MLP but instead of dense layers I use Resnet blocks
- Tabnet
- TabTransformer
ADULT CENSUS
model | acc | runtime | best_epoch_or_ntrees |
---|---|---|---|
lightgbm | 0.878178 | 0.908639 | 408.0 |
tabmlp | 0.872209 | 205.357588 | 62.0 |
tabtransformer | 0.871767 | 288.640581 | 32.0 |
tabnet | 0.870440 | 422.296659 | 26.0 |
tabresnet | 0.869777 | 388.932547 | 25.0 |
BANK MARKETING
model | f1 | auc | runtime | best_epoch_or_ntrees |
---|---|---|---|---|
tabresnet | 0.429799 | 0.650147 | 92.517464 | 11.0 |
tabtransformer | 0.419971 | 0.643972 | 31.693761 | 4.0 |
tabmlp | 0.385542 | 0.628082 | 9.572095 | 7.0 |
lightgbm | 0.385208 | 0.626490 | 0.461398 | 57.0 |
tabnet | 0.308703 | 0.594316 | 77.878060 | 13.0 |
NYC TAXI RIDE DURATION
model | rmse | r2 | runtime | best_epoch_or_ntrees |
---|---|---|---|---|
lightgbm | 262.709865 | 0.804393 | 42.721136 | 504.0 |
tabmlp | 271.342218 | 0.791327 | 568.430923 | 24.0 |
tabresnet | 292.890792 | 0.756867 | 471.264983 | 24.0 |
tabtransformer | 336.582554 | 0.678919 | 5779.031367 | 54.0 |
tabnet | 376.053004 | 0.599198 | 1844.472289 | 15.0 |
FACEBOOK COMMENT VOLUME
model | rmse | r2 | runtime | best_epoch_or_ntrees |
---|---|---|---|---|
lightgbm | 5.528963 | 0.823208 | 6.525877 | 687.0 |
tabmlp | 5.908498 | 0.798103 | 250.476762 | 43.0 |
tabtransformer | 5.925587 | 0.796933 | 533.390816 | 27.0 |
tabresnet | 6.213813 | 0.776698 | 70.466089 | 9.0 |
tabnet | 6.428503 | 0.761001 | 935.020483 | 59.0 |
For more results on all the experiments run see here